diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a3780b --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +.vscode +data \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d1ae5a9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Inria + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 0aa1541..60b9c88 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,261 @@
-## Material Palette: Extraction of Materials from a Single Image - +## Material Palette: Extraction of Materials from a Single Image (CVPR 2024) +
Ivan Lopes1   Fabio Pizzati2   Raoul de Charette1
1 Inria, - 2 University of Oxford + 2 Oxford Uni.

-[![arXiv](https://img.shields.io/badge/arXiv-2311.17060-darkred?style=flat-square&logo=arxiv)](https://arxiv.org/abs/2311.17060) -[![Project page](https://img.shields.io/badge/page-Material_Palette-darkgreen?style=flat-square&logo=github)](https://astra-vision.github.io/MaterialPalette/) +[![arXiv](https://img.shields.io/badge/arXiv-_-darkgreen?style=flat-square&logo=arxiv)](https://arxiv.org/abs/2311.17060) +[![Project page](https://img.shields.io/badge/πŸš€_Project_Page-_-darkgreen?style=flat-square)](https://astra-vision.github.io/MaterialPalette/) +[![cvf](https://img.shields.io/badge/CVPR_2024-_-darkgreen?style=flat-square)](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers#:~:text=Material%20Palette%3A%20Extraction%20of%20Materials%20from%20a%20Single%20Image) +[![dataset](https://img.shields.io/badge/πŸ€—_dataset-(soon)-darkred?style=flat-square)](#) +[![star](https://img.shields.io/badge/⭐_star--darkgreen?style=flat-square)](https://github.com/astra-vision/MaterialPalette/stargazers) + + +TL;DR, Material Palette extracts a palette of PBR materials -
albedo, normals, and roughness - from a single real-world image. -The code source will be published soon, stay tuned for more! [![arXiv](https://img.shields.io/badge/-⭐-black?style=flat-square)](https://github.com/astra-vision/MaterialPalette/stargazers)
-https://github.com/astra-vision/MaterialPalette/assets/30524163/44e45e58-7c7d-49a3-8b6e-ec6b99cf9c62 +https://github.com/wonjunior/readme/assets/30524163/1f5d8d3f-e7f0-449e-8230-bd076cca7884 + + + +* [Overview](#1-overview) +* [1. Installation](#1-installation) +* [2. Quick Start](#2-quick-start) + * [Generation](#-generation) + * [Complete Pipeline](#-complete-pipeline) +* [3. Project Structure](#3-project-structure) +* [4. (optional) Retraining](#4-optional-training) +* [5. Acknowledgments](#5-acknowledgments) +* [6. Licence](#6-licence) + + +## Todo + +- 🚨 Release of the TexSD dataset! +- 🚨 3D rendering script. + +## Overview + +This is the official repository of [**Material Palette**](https://astra-vision.github.io/MaterialPalette/). In a nutshell, the method works in three stages: first, concepts are extracted from an input image based on a user provided mask; then, those concepts are used to generate texture images; finally the generations are decomposed into SVBRDF maps (albedo, normals, and roughness). Visit our project page of consult our paper for more details! + +![pipeline](https://github.com/astra-vision/MaterialPalette/assets/30524163/be03b0ca-bee2-4fc7-bebd-9519c3c4947d) + +**Content**: This repository allows the extraction of texture concepts from image and region mask sets. It also allows generation at different resolutions. Finally it proposes a decomposition step thanks to our decomposition model, for which we share the training weights. + +> [!TIP] +> We propose a ["Quick Start"](#2-quick-start) section: before diving straight into the full pipeline, we share four pretrained concepts ⚑ so you can go ahead and experiment with the texture generation step of the method: see ["Β§ Generation"](#-generation). Then you can try out the full method with your own image and masks = concept learning + generation + decomposition, see ["Β§ Complete Pipeline"](#-complete-pipeline). + + +## 1. Installation + + 1. Download the source code with git + ``` + git clone https://github.com/astra-vision/MaterialPalette.git + ``` + The repo can also be download as a zip [here](https://github.com/astra-vision/MaterialPalette/archive/refs/heads/master.zip). + + 2. Create a conda environment with the dependencies. + ``` + conda env create --verbose -f deps.yml + ``` + This repo was tested with [**Python**](https://www.python.org/doc/versions/) 3.10.8, [**PyTorch**](https://pytorch.org/get-started/previous-versions/) 1.13, [**diffusers**](https://huggingface.co/docs/diffusers/installation) 0.19.3, [**peft**](https://huggingface.co/docs/peft/en/install) 0.5, and [**PyTorch Lightning**](https://lightning.ai/docs/pytorch/stable/past_versions.html) 1.8.3. + + 3. Load the conda environment: + ``` + conda activate matpal + ``` + + 4. If you are looking to perform decomposition, download our pretrained model: + ``` + wget https://github.com/astra-vision/MaterialPalette/archive/refs/heads/model.ckpt + ``` + This is not required if your are only looking to perform texture extraction + + +## 2. Quick start + +Here are instructions to get you started using **Material Palette**. First, we provide some optimized concepts so you can experiment with the generation pipeline. 
+We also show how to run the method on user-selected images and masks (concept learning + generation + decomposition).
+
+### § Generation
+
+| Input image | 1K | 2K | 4K | 8K | ⬇️ LoRA ~8Kb |
+| :-: | :-: | :-: | :-: | :-: | :-: |
+| J | J | J | J | J | [![x](https://img.shields.io/badge/-⚡blue_tiles.zip-black)](https://github.com/astra-vision/MaterialPalette/files/14601640/blue_tiles.zip)
+| J | J | J | J | J | [![x](https://img.shields.io/badge/-⚡cat_fur.zip-black)](https://github.com/astra-vision/MaterialPalette/files/14601641/cat_fur.zip)
+| J | J | J | J | J | [![x](https://img.shields.io/badge/-⚡damaged.zip-black)](https://github.com/astra-vision/MaterialPalette/files/14601642/damaged.zip)
+| J | J | J | J | J | [![x](https://img.shields.io/badge/-⚡ivy_bricks.zip-black)](https://github.com/astra-vision/MaterialPalette/files/14601643/ivy_bricks.zip)
+
+All generations were downscaled due to memory constraints.
+
+Go ahead and download one of the above LoRA concept checkpoints, for example "blue_tiles":
+
+```
+wget https://github.com/astra-vision/MaterialPalette/files/14601640/blue_tiles.zip;
+unzip blue_tiles.zip
+```
+To generate from a checkpoint, use the [`concept`](./concept/) module either via the command line interface or the functional interface in Python:
+- ![](https://img.shields.io/badge/$-command_line-white?style=flat-square)
+  ```
+  python concept/infer.py path/to/LoRA/checkpoint
+  ```
+- ![](https://img.shields.io/badge/-python-white?style=flat-square&logo=python)
+  ```
+  import concept
+  concept.infer(path_to_LoRA_checkpoint)
+  ```
+
+Results will be placed in an `outputs` folder relative to the checkpoint directory.
+
+You have control over the following parameters:
+- `stitch_mode`: concatenation, average, or weighted average (*default*);
+- `resolution`: the output resolution of the generated texture;
+- `prompt`: one of the four prompt templates:
+  - `"p1"`: `"top view realistic texture of S*"`,
+  - `"p2"`: `"top view realistic S* texture"`,
+  - `"p3"`: `"high resolution realistic S* texture in top view"`,
+  - `"p4"`: `"realistic S* texture in top view"`;
+- `seed`: inference seed when sampling noise;
+- `renorm`: whether or not to renormalize the generated samples based on the input image (this option can only be used when called from inside the pipeline, *i.e.* when the input image is available);
+- `num_inference_steps`: number of denoising steps.
+
+A complete list of parameters can be viewed with `python concept/infer.py --help`.
+
+
+### § Complete Pipeline
+
+We provide an example (the input image and user masks used for the pipeline figure). You can download it here: [**mansion.zip**](https://github.com/astra-vision/MaterialPalette/files/14619163/mansion.zip) (photograph credit: [Max Rahubovskiy](https://www.pexels.com/@heyho/)).
+
+To help you get started with your own images, follow this simple data structure: one folder per image to invert, containing the input image (`.jpg`, `.jpeg`, or `.png`) and a subdirectory named `masks` with the different region masks as `.png` files (these **must all have the same aspect ratio** as the RGB image).
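+Before running the pipeline on your own folder, you can sanity-check this layout with a small script; the sketch below is not part of the repository and assumes Pillow is available in the environment:
+
+```
+from pathlib import Path
+from PIL import Image
+
+def check_folder(path):
+    """Check one input folder: a single image at its root, same-aspect-ratio masks in masks/."""
+    path = Path(path)
+    images = [p for p in path.iterdir() if p.suffix.lower() in ('.jpg', '.jpeg', '.png')]
+    assert len(images) == 1, f'expected a single input image, found {len(images)}'
+    w, h = Image.open(images[0]).size
+    masks = sorted((path / 'masks').glob('*.png'))
+    assert masks, 'no region masks found in masks/'
+    for m in masks:
+        mw, mh = Image.open(m).size
+        # compare aspect ratios by cross-multiplication to avoid float rounding
+        assert mw * h == mh * w, f'{m.name} does not match the {w}x{h} aspect ratio of the input'
+    print(f'{path.name}: {len(masks)} masks OK')
+
+check_folder('mansion')  # e.g. the extracted mansion.zip example
+```
+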
+Here is an overview of our mansion example:
+```
+├── masks/
+│   ├── wood.png
+│   ├── grass.png
+│   └── stone.png
+└── mansion.jpg
+```
+
+|region|mask|overlay|generation|albedo|normals|roughness|
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+|![#6C8EBF](https://placehold.co/15x15/6C8EBF/6C8EBF.png) | J|J|J|J|J|J|
+|![#EDB01A](https://placehold.co/15x15/EDB01A/EDB01A.png) | J|J|J|J|J|J|
+|![#AA4A44](https://placehold.co/15x15/AA4A44/AA4A44.png) | J|J|J|J|J|J|
+
+
+To invert and generate textures from a folder, use [`pipeline.py`](./pipeline.py):
+
+- ![](https://img.shields.io/badge/$-command_line-white?style=flat-square)
+  ```
+  python pipeline.py path/to/folder
+  ```
+
+Under the hood, it uses two modules:
+1. [`concept`](./concept), to extract and generate the texture ([`concept.crop`](./concept/crop.py), [`concept.invert`](./concept/invert.py), and [`concept.infer`](./concept/infer.py));
+2. [`capture`](./capture/), to perform the BRDF decomposition.
+
+A minimal example is provided here:
+
+- ![](https://img.shields.io/badge/-python-white?style=flat-square&logo=python)
+  ```
+  import concept
+  import capture
+  from pytorch_lightning import Trainer
+
+  ## Extract square crops from the image for each of the binary masks located in masks/
+  regions = concept.crop(args.path)
+
+  ## Iterate through regions to invert the concept and generate texture views
+  for region in regions.iterdir():
+      lora = concept.invert(region)
+      concept.infer(lora, renorm=True)
+
+  ## Construct a dataset with all generations and load pretrained decomposition model
+  data = capture.get_data(predict_dir=args.path, predict_ds='sd')
+  module = capture.get_inference_module(pt='model.ckpt')
+
+  ## Proceed with inference on decomposition model
+  decomp = Trainer(default_root_dir=args.path, accelerator='gpu', devices=1, precision=16)
+  decomp.predict(module, data)
+  ```
+To view the options available for concept learning, use ``PYTHONPATH=. python concept/invert.py --help``.
+
+> [!IMPORTANT]
+> By default, both `train_text_encoder` and `gradient_checkpointing` are set to `True`. Also, this implementation does not include the `LPIPS` filter/ranking of the generations. The code will only output a single sample per region. You may experiment with different prompts and parameters (see the ["Generation"](#-generation) section).
+
+## 3. Project structure
+
+The [`pipeline.py`](./pipeline.py) file is the entry point to run the whole pipeline on a folder containing the input image at its root and a `masks/` sub-directory containing all user-defined masks. The [`train.py`](./train.py) file is used to train the decomposition model. The most important files are shown here:
+```
+.
+├── capture/           % Module for decomposition
+│   ├── callbacks/     % Lightning trainer callbacks
+│   ├── data/          % Dataset, subsets, Lightning datamodules
+│   ├── render/        % 2D physics-based renderer
+│   ├── utils/         % Utility functions
+│   └── source/        % Network, loss, and LightningModule
+│       └── routine.py % Training loop
+│
+└── concept/           % Module for inversion and texture generation
+    ├── crop.py        % Square crop extraction from image and masks
+    ├── invert.py      % Optimization code to learn the concept S*
+    └── infer.py       % Inference code to generate texture from S*
+```
+If you have any questions, post via the [*issues tracker*](https://github.com/astra-vision/MaterialPalette/issues) or contact the corresponding author.
+
+## 4. (optional) Training
+
+We provide the pretrained decomposition weights (see ["Installation"](#1-installation)).
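+For decomposition only (for instance to re-run it on previously generated textures), a minimal sketch based on the example above is shown below; it assumes `model.ckpt` was downloaded as described in ["Installation"](#1-installation) and that the target folder already contains the generations expected by the `sd` predict set:
+
+```
+import capture
+from pytorch_lightning import Trainer
+
+path = 'path/to/folder'  # laid out as described in "Complete Pipeline"
+data = capture.get_data(predict_dir=path, predict_ds='sd')
+module = capture.get_inference_module(pt='model.ckpt')
+Trainer(default_root_dir=path, accelerator='gpu', devices=1, precision=16).predict(module, data)
+```
+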
+However, if you are looking to retrain the domain-adaptive model for your own purposes, we provide the code to do so. Our method relies on training a multi-task network *jointly* on labeled (real) and unlabeled (synthetic) images. If you wish to retrain on the same datasets, you will have to download both the ***AmbientCG*** and ***TexSD*** datasets (⚠️ TexSD will be released soon).
+
+First, download the PBR materials (source) dataset from [AmbientCG](https://ambientcg.com/):
+```
+python capture/data/download.py path/to/target/directory
+```
+
+To run the training script, use:
+```
+python train.py --config=path/to/yml/config
+```
+
+Additional options can be found with `python train.py --help`.
+
+> [!NOTE]
+> The decomposition model allows estimating the pixel-wise BRDF maps from a single texture image input.
+
+## 5. Acknowledgments
+This research project was mainly funded by the French Agence Nationale de la Recherche (ANR) as part of project SIGHT (ANR-20-CE23-0016). Fabio Pizzati was partially funded by KAUST (Grant DFR07910). Results were obtained using HPC resources from GENCI-IDRIS (Grant 2023-AD011014389).
+
+The repository contains code taken from [`PEFT`](https://github.com/huggingface/peft), [`SVBRDF-Estimation`](https://github.com/mworchel/svbrdf-estimation/tree/master), and [`DenseMTL`](https://github.com/astra-vision/DenseMTL). As for visualization, we used [`DeepBump`](https://github.com/HugoTini/DeepBump) and [**Blender**](https://www.blender.org/). Credit to Runway for providing the [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model weights. All images and 3D scenes used in this work have permissive licences. Special credits to [**AmbientCG**](https://ambientcg.com/list) for their huge work.
+
+The authors would also like to thank all members of the [Astra-Vision](https://astra-vision.github.io/) team for their valuable feedback.
+
+## 6. Licence
+If you find this code useful, please cite our paper:
+```
+@inproceedings{lopes2024material,
    author = {Lopes, Ivan and Pizzati, Fabio and de Charette, Raoul},
    title = {Material Palette: Extraction of Materials from a Single Image},
-   booktitle = {arXiv},
-   year = {2023},
+   booktitle = {CVPR},
+   year = {2024},
    project = {https://astra-vision.github.io/MaterialPalette/}
}
```
+**Material Palette** is released under [MIT License](./LICENSE).
+
+---
-### Acknowledgments
-This research project was funded by the French Agence Nationale de la Recherche (ANR) as part of project SIGHT (ANR-20-CE23-0016). It was performed using HPC resources from GENCI-IDRIS (Grant 2023-AD011014389).
+[💁 jump to top](#material-palette-extraction-of-materials-from-a-single-image-cvpr-2024) \ No newline at end of file diff --git a/capture/__init__.py b/capture/__init__.py new file mode 100644 index 0000000..8b0e2fa --- /dev/null +++ b/capture/__init__.py @@ -0,0 +1,6 @@ + +from .utils.model import get_inference_module +from .utils.exp import get_data + + +__all__ = ['get_inference_module', 'get_data'] \ No newline at end of file diff --git a/capture/callbacks/__init__.py b/capture/callbacks/__init__.py new file mode 100644 index 0000000..51503b0 --- /dev/null +++ b/capture/callbacks/__init__.py @@ -0,0 +1,4 @@ +from .metrics import MetricLogging +from .visualize import VisualizeCallback + +__all__ = ['MetricLogging', 'VisualizeCallback'] \ No newline at end of file diff --git a/capture/callbacks/metrics.py b/capture/callbacks/metrics.py new file mode 100644 index 0000000..667924b --- /dev/null +++ b/capture/callbacks/metrics.py @@ -0,0 +1,28 @@ +import os +from pathlib import Path +from collections import OrderedDict + +from pytorch_lightning.callbacks import Callback + +from ..utils.log import append_csv, get_info + + +class MetricLogging(Callback): + def __init__(self, weights: str, test_list: str, outdir: Path): + super().__init__() + assert outdir.is_dir() + + self.weights = weights + self.test_list = test_list + self.outpath = outdir/'eval.csv' + + def on_test_end(self, trainer, pl_module): + weight_name, epoch = get_info(str(self.weights)) + *_, test_set = self.test_list.parts + + parsed = {k: f'{v}' for k,v in trainer.logged_metrics.items()} + + odict = OrderedDict(name=weight_name, epoch=epoch, test_set=test_set) + odict.update(parsed) + append_csv(self.outpath, odict) + print(f'logged metrics in: {self.outpath}') \ No newline at end of file diff --git a/capture/callbacks/visualize.py b/capture/callbacks/visualize.py new file mode 100644 index 0000000..92a3692 --- /dev/null +++ b/capture/callbacks/visualize.py @@ -0,0 +1,85 @@ +from pathlib import Path + +import torch +import pytorch_lightning as pl +from torchvision.utils import make_grid, save_image +from torchvision.transforms import Resize + +from capture.render import encode_as_unit_interval, gamma_encode + + +class VisualizeCallback(pl.Callback): + def __init__(self, exist_ok: bool, out_dir: Path, log_every_n_epoch: int, n_batches_shown: int): + super().__init__() + + self.out_dir = out_dir/'images' + if not exist_ok and (self.out_dir.is_dir() and len(list(self.out_dir.iterdir())) > 0): + print(f'directory {out_dir} already exists, press \'y\' to proceed') + x = input() + if x != 'y': + exit(1) + + self.out_dir.mkdir(parents=True, exist_ok=True) + + self.log_every_n_epoch = log_every_n_epoch + self.n_batches_shown = n_batches_shown + self.resize = Resize(size=[128,128], antialias=True) + + def setup(self, trainer, module, stage): + self.logger = trainer.logger + + def on_train_batch_end(self, *args): + self._on_batch_end(*args, split='train') + + def on_validation_batch_end(self, *args): + self._on_batch_end(*args, split='valid') + + def _on_batch_end(self, trainer, module, outputs, inputs, batch, *args, split): + x_src, x_tgt = inputs + + # optim_idx:0=discr & optim_idx:1=generator + y_src, y_tgt = outputs[1]['y'] if isinstance(outputs, list) else outputs['y'] + + epoch = trainer.current_epoch + if epoch % self.log_every_n_epoch == 0 and batch <= self.n_batches_shown: + if x_src and y_src: + self._visualize_src(x_src, y_src, split=split, epoch=epoch, batch=batch, ds='src') + if x_tgt and y_tgt: + self._visualize_tgt(x_tgt, 
y_tgt, split=split, epoch=epoch, batch=batch, ds='tgt') + + def _visualize_src(self, x, y, split, epoch, batch, ds): + zipped = zip(x.albedo, x.roughness, x.normals, x.displacement, x.input, x.image, + y.albedo, y.roughness, y.normals, y.displacement, y.reco, y.image) + + grid = [self._visualize_single_src(*z) for z in zipped] + + name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg' + save_image(grid, name, nrow=1, padding=5) + + @torch.no_grad() + def _visualize_single_src(self, a, r, n, d, input, mv, a_p, r_p, n_p, d_p, reco, mv_p): + n = encode_as_unit_interval(n) + n_p = encode_as_unit_interval(n_p) + + mv_gt = [gamma_encode(o) for o in mv] + mv_pred = [gamma_encode(o) for o in mv_p] + reco = gamma_encode(reco) + + maps = [input, a, r, n, d] + mv_gt + [reco, a_p, r_p, n_p, d_p] + mv_pred + maps = [self.resize(x.cpu()) for x in maps] + return make_grid(maps, nrow=len(maps)//2, padding=0) + + def _visualize_tgt(self, x, y, split, epoch, batch, ds): + zipped = zip(x.input, y.albedo, y.roughness, y.normals, y.displacement) + + grid = [self._visualize_single_tgt(*z) for z in zipped] + + name = self.out_dir/f'{split}{epoch:05d}_{ds}_{batch}.jpg' + save_image(grid, name, nrow=1, padding=5) + + @torch.no_grad() + def _visualize_single_tgt(self, input, a_p, r_p, n_p, d_p): + n_p = encode_as_unit_interval(n_p) + maps = [input, a_p, r_p, n_p, d_p] + maps = [self.resize(x.cpu()) for x in maps] + return make_grid(maps, nrow=len(maps), padding=0) \ No newline at end of file diff --git a/capture/predict.yml b/capture/predict.yml new file mode 100644 index 0000000..a486487 --- /dev/null +++ b/capture/predict.yml @@ -0,0 +1,26 @@ +archi: densemtl +mode: predict +logger: + project: ae_acg +data: + batch_size: 1 + num_workers: 10 + input_size: 512 + predict_ds: sd + predict_list: data/matlist/pbrsd_v2 +trainer: + accelerator: gpu + devices: 1 + precision: 16 +routine: + lr: 2e-5 +loss: + use_source: True + use_target: False + reg_weight: 1 + render_weight: 1 + n_random_configs: 3 + n_symmetric_configs: 6 +viz: + n_batches_shown: 5 + log_every_n_epoch: 5 \ No newline at end of file diff --git a/capture/render/__init__.py b/capture/render/__init__.py new file mode 100644 index 0000000..591c0e6 --- /dev/null +++ b/capture/render/__init__.py @@ -0,0 +1,4 @@ +from .main import Renderer +from .scene import Scene, generate_random_scenes, generate_specular_scenes, gamma_decode, gamma_encode, encode_as_unit_interval, decode_from_unit_interval + +__all__ = ['Renderer', 'Scene', 'generate_random_scenes', 'generate_specular_scenes', 'gamma_decode', 'gamma_encode', 'encode_as_unit_interval', 'decode_from_unit_interval'] \ No newline at end of file diff --git a/capture/render/main.py b/capture/render/main.py new file mode 100644 index 0000000..b0f9d02 --- /dev/null +++ b/capture/render/main.py @@ -0,0 +1,153 @@ +import torch +import numpy as np + +from .scene import Light, Scene, Camera, dot_product, normalize, generate_normalized_random_direction, gamma_encode + + +class Renderer: + def __init__(self, return_params=False): + self.use_augmentation = False + self.return_params = return_params + + def xi(self, x): + return (x > 0.0) * torch.ones_like(x) + + def compute_microfacet_distribution(self, roughness, NH): + alpha = roughness**2 + alpha_squared = alpha**2 + NH_squared = NH**2 + denominator_part = torch.clamp(NH_squared * (alpha_squared + (1 - NH_squared) / NH_squared), min=0.001) + return (alpha_squared * self.xi(NH)) / (np.pi * denominator_part**2) + + def compute_fresnel(self, F0, VH): + # 
https://cdn2.unrealengine.com/Resources/files/2013SiggraphPresentationsNotes-26915738.pdf + return F0 + (1.0 - F0) * (1.0 - VH)**5 + + def compute_g1(self, roughness, XH, XN): + alpha = roughness**2 + alpha_squared = alpha**2 + XN_squared = XN**2 + return 2 * self.xi(XH / XN) / (1 + torch.sqrt(1 + alpha_squared * (1.0 - XN_squared) / XN_squared)) + + def compute_geometry(self, roughness, VH, LH, VN, LN): + return self.compute_g1(roughness, VH, VN) * self.compute_g1(roughness, LH, LN) + + def compute_specular_term(self, wi, wo, albedo, normals, roughness, metalness): + F0 = 0.04 * (1. - metalness) + metalness * albedo + + # Compute the half direction + H = normalize((wi + wo) / 2.0) + + # Precompute some dot product + NH = torch.clamp(dot_product(normals, H), min=0.001) + VH = torch.clamp(dot_product(wo, H), min=0.001) + LH = torch.clamp(dot_product(wi, H), min=0.001) + VN = torch.clamp(dot_product(wo, normals), min=0.001) + LN = torch.clamp(dot_product(wi, normals), min=0.001) + + F = self.compute_fresnel(F0, VH) + G = self.compute_geometry(roughness, VH, LH, VN, LN) + D = self.compute_microfacet_distribution(roughness, NH) + + return F * G * D / (4.0 * VN * LN) + + def compute_diffuse_term(self, albedo, metalness): + return albedo * (1. - metalness) / np.pi + + def evaluate_brdf(self, wi, wo, normals, albedo, roughness, metalness): + diffuse_term = self.compute_diffuse_term(albedo, metalness) + specular_term = self.compute_specular_term(wi, wo, albedo, normals, roughness, metalness) + return diffuse_term, specular_term + + def render(self, scene, svbrdf): + normals, albedo, roughness, displacement = svbrdf + device = albedo.device + + # Generate surface coordinates for the material patch + # The center point of the patch is located at (0, 0, 0) which is the center of the global coordinate system. + # The patch itself spans from (-1, -1, 0) to (1, 1, 0). + xcoords_row = torch.linspace(-1, 1, albedo.shape[-1], device=device) + xcoords = xcoords_row.unsqueeze(0).expand(albedo.shape[-2], albedo.shape[-1]).unsqueeze(0) + ycoords = -1 * torch.transpose(xcoords, dim0=1, dim1=2) + coords = torch.cat((xcoords, ycoords, torch.zeros_like(xcoords)), dim=0) + + # We treat the center of the material patch as focal point of the camera + camera_pos = scene.camera.pos.unsqueeze(-1).unsqueeze(-1).to(device) + relative_camera_pos = camera_pos - coords + wo = normalize(relative_camera_pos) + + # Avoid zero roughness (i. 
e., potential division by zero) + roughness = torch.clamp(roughness, min=0.001) + + light_pos = scene.light.pos.unsqueeze(-1).unsqueeze(-1).to(device) + relative_light_pos = light_pos - coords + wi = normalize(relative_light_pos) + + fdiffuse, fspecular = self.evaluate_brdf(wi, wo, normals, albedo, roughness, metalness=0) + f = fdiffuse + fspecular + + color = scene.light.color if torch.is_tensor(scene.light.color) else torch.tensor(scene.light.color) + light_color = color.unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) + falloff = 1.0 / torch.sqrt(dot_product(relative_light_pos, relative_light_pos))**2 # Radial light intensity falloff + LN = torch.clamp(dot_product(wi, normals), min=0.0) # Only consider the upper hemisphere + radiance = torch.mul(torch.mul(f, light_color * falloff), LN) + + return radiance + + def _get_input_params(self, n_samples, light, pose): + min_eps = 0.001 + max_eps = 0.02 + light_distance = 2.197 + view_distance = 2.75 + + # Generate scenes (camera and light configurations) + # In the first configuration, the light and view direction are guaranteed to be perpendicular to the material sample. + # For the remaining cases, both are randomly sampled from a hemisphere. + view_dist = torch.ones(n_samples-1) * view_distance + if pose is None: + view_poses = torch.cat([torch.Tensor(2).uniform_(-0.25, 0.25), torch.ones(1) * view_distance], dim=-1).unsqueeze(0) + if n_samples > 1: + hemi_views = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * view_distance + view_poses = torch.cat([view_poses, hemi_views]) + else: + assert torch.is_tensor(pose) + view_poses = pose.cpu() + + if light is None: + light_poses = torch.cat([torch.Tensor(2).uniform_(-0.75, 0.75), torch.ones(1) * light_distance], dim=-1).unsqueeze(0) + if n_samples > 1: + hemi_lights = generate_normalized_random_direction(n_samples - 1, min_eps=min_eps, max_eps=max_eps) * light_distance + light_poses = torch.cat([light_poses, hemi_lights]) + else: + assert torch.is_tensor(light) + light_poses = light.cpu() + + light_colors = torch.Tensor([10.0]).unsqueeze(-1).expand(n_samples, 3) + + return view_poses, light_poses, light_colors + + def __call__(self, svbrdf, n_samples=1, lights=None, poses=None): + view_poses, light_poses, light_colors = self._get_input_params(n_samples, lights, poses) + + renderings = [] + for wo, wi, c in zip(view_poses, light_poses, light_colors): + scene = Scene(Camera(wo), Light(wi, c)) + rendering = self.render(scene, svbrdf) + + # Simulate noise + std_deviation_noise = torch.exp(torch.Tensor(1).normal_(mean = np.log(0.005), std=0.3)).numpy()[0] + noise = torch.zeros_like(rendering).normal_(mean=0.0, std=std_deviation_noise) + + # clipping + post_noise = torch.clamp(rendering + noise, min=0.0, max=1.0) + + # gamma encoding + post_gamma = gamma_encode(post_noise) + + renderings.append(post_gamma) + + renderings = torch.cat(renderings, dim=0) + + if self.return_params: + return renderings, (view_poses, light_poses, light_colors) + return renderings \ No newline at end of file diff --git a/capture/render/scene.py b/capture/render/scene.py new file mode 100644 index 0000000..f144733 --- /dev/null +++ b/capture/render/scene.py @@ -0,0 +1,105 @@ +import math +import torch + + +def encode_as_unit_interval(tensor): + """ + Maps range [-1, 1] to [0, 1] + """ + return (tensor + 1) / 2 + +def decode_from_unit_interval(tensor): + """ + Maps range [0, 1] to [-1, 1] + """ + return tensor * 2 - 1 + +def gamma_decode(images): + return torch.pow(images, 2.2) + +def 
gamma_encode(images): + return torch.pow(images, 1.0/2.2) + +def dot_product(a, b): + return torch.sum(torch.mul(a, b), dim=-3, keepdim=True) + +def normalize(a): + return torch.div(a, torch.sqrt(dot_product(a, a))) + +def generate_normalized_random_direction(count, min_eps = 0.001, max_eps = 0.05): + r1 = torch.Tensor(count, 1).uniform_(0.0 + min_eps, 1.0 - max_eps) + r2 = torch.Tensor(count, 1).uniform_(0.0, 1.0) + + r = torch.sqrt(r1) + phi = 2 * math.pi * r2 + + x = r * torch.cos(phi) + y = r * torch.sin(phi) + z = torch.sqrt(1.0 - r**2) + + return torch.cat([x, y, z], axis=-1) + +def generate_random_scenes(count): + # Randomly distribute both, view and light positions + view_positions = generate_normalized_random_direction(count, 0.001, 0.1) # shape = [count, 3] + light_positions = generate_normalized_random_direction(count, 0.001, 0.1) + + scenes = [] + for i in range(count): + c = Camera(view_positions[i]) + # Light has lower power as the distance to the material plane is not as large + l = Light(light_positions[i], [20.]*3) + scenes.append(Scene(c, l)) + + return scenes + +def generate_specular_scenes(count): + # Only randomly distribute view positions and place lights in a perfect mirror configuration + view_positions = generate_normalized_random_direction(count, 0.001, 0.1) # shape = [count, 3] + light_positions = view_positions * torch.Tensor([-1.0, -1.0, 1.0]).unsqueeze(0) + + # Reference: "parameters chosen empirically to have a nice distance from a -1;1 surface."" + distance_view = torch.exp(torch.Tensor(count, 1).normal_(mean=0.5, std=0.75)) + distance_light = torch.exp(torch.Tensor(count, 1).normal_(mean=0.5, std=0.75)) + + # Reference: "Shift position to have highlight elsewhere than in the center." + # NOTE: This code only creates guaranteed specular highlights in the orthographic rendering, not in the perspective one. + # This is because the camera is -looking- at the center of the patch. 
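+    # The shift below stays (almost) in the z=0 surface plane: a random xy offset with a tiny
+    # positive z, applied to both the scaled view positions and the mirrored light positions.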
+ shift = torch.cat([torch.Tensor(count, 2).uniform_(-1.0, 1.0), torch.zeros((count, 1)) + 0.0001], dim=-1) + + view_positions = view_positions * distance_view + shift + light_positions = light_positions * distance_light + shift + + scenes = [] + for i in range(count): + c = Camera(view_positions[i]) + l = Light(light_positions[i], [20, 20.0, 20.0]) + scenes.append(Scene(c, l)) + + return scenes + +class Camera: + def __init__(self, pos): + self.pos = pos + def __str__(self): + return f'Camera({self.pos.tolist()})' + +class Light: + def __init__(self, pos, color): + self.pos = pos + self.color = color + def __str__(self): + return f'Light({self.pos.tolist()}, {self.color})' + +class Scene: + def __init__(self, camera, light): + self.camera = camera + self.light = light + def __str__(self): + return f'Scene({self.camera}, {self.light})' + @classmethod + def load(cls, o): + cam, light, color = o + return Scene(Camera(cam), Light(light, color)) + def export(self): + return [self.camera.pos, self.light.pos, self.light.color] diff --git a/capture/source/__init__.py b/capture/source/__init__.py new file mode 100644 index 0000000..75c5eab --- /dev/null +++ b/capture/source/__init__.py @@ -0,0 +1,5 @@ +from .model import ResnetEncoder, MultiHeadDecoder, DenseMTL +from .loss import DenseReg, RenderingLoss +from .routine import Vanilla + +__all__ = ['ResnetEncoder', 'MultiHeadDecoder', 'DenseMTL', 'DenseReg', 'RenderingLoss', 'Vanilla'] \ No newline at end of file diff --git a/capture/source/loss.py b/capture/source/loss.py new file mode 100644 index 0000000..8ac7104 --- /dev/null +++ b/capture/source/loss.py @@ -0,0 +1,143 @@ +from pathlib import Path + +import torch +import torch.nn as nn +from easydict import EasyDict +import torch.nn.functional as F +import torchvision.transforms.functional as tf + +from ..render import Renderer, Scene, generate_random_scenes, generate_specular_scenes + + +class RenderingLoss(nn.Module): + def __init__(self, renderer, n_random_configs=0, n_symmetric_configs=0): + super().__init__() + self.eps = 0.1 + self.renderer = renderer + self.n_random_configs = n_random_configs + self.n_symmetric_configs = n_symmetric_configs + self.n_renders = n_random_configs + n_symmetric_configs + + def generate_scenes(self): + return generate_random_scenes(self.n_random_configs) + generate_specular_scenes(self.n_symmetric_configs) + + def multiview_render(self, y, x): + X_renders, Y_renders = [], [] + + x_svBRDFs = zip(x.normals, x.albedo, x.roughness, x.displacement) + y_svBRDFs = zip(y.normals, y.albedo, y.roughness, x.displacement) + for x_svBRDF, y_svBRDF in zip(x_svBRDFs, y_svBRDFs): + x_renders, y_renders = [], [] + for scene in self.generate_scenes(): + x_renders.append(self.renderer.render(scene, x_svBRDF)) + y_renders.append(self.renderer.render(scene, y_svBRDF)) + X_renders.append(torch.cat(x_renders)) + Y_renders.append(torch.cat(y_renders)) + + out = torch.stack(X_renders), torch.stack(Y_renders) + return out + + def reconstruction(self, y, theta): + views = [] + for *svBRDF, t in zip(y.normals, y.albedo, y.roughness, y.displacement, theta): + render = self.renderer.render(Scene.load(t), svBRDF) + views.append(render) + return torch.cat(views) + + def __call__(self, y, x, **kargs): + loss = F.l1_loss(torch.log(y + self.eps), torch.log(x + self.eps), **kargs) + return loss + +class DenseReg(nn.Module): + def __init__( + self, + reg_weight: float, + render_weight: float, + pl_reg_weight: float = 0., + pl_render_weight: float = 0., + use_source: bool = True, + use_target: 
bool = True, + n_random_configs= 3, + n_symmetric_configs = 6, + ): + super().__init__() + + self.weights = [('albedo', reg_weight, self.log_l1), + ('roughness', reg_weight, self.log_l1), + ('normals', reg_weight, F.l1_loss)] + + self.reg_weight = reg_weight + self.render_weight = render_weight + self.pl_reg_weight = pl_reg_weight + self.pl_render_weight = pl_render_weight + self.use_source = use_source + self.use_target = use_target + + self.renderer = Renderer() + self.n_random_configs = n_random_configs + self.n_symmetric_configs = n_symmetric_configs + self.loss = RenderingLoss(self.renderer, n_random_configs=n_random_configs, n_symmetric_configs=n_symmetric_configs) + + def log_l1(self, x, y, **kwargs): + return F.l1_loss(torch.log(x + 0.01), torch.log(y + 0.01), **kwargs) + + def forward(self, x, y): + loss = EasyDict() + x_src, x_tgt = x + y_src, y_tgt = y + + if self.use_source: + # acg regression loss + for k, w, loss_fn in self.weights: + loss[k] = w*loss_fn(y_src[k], x_src[k]) + + # rendering loss + x_src.image, y_src.image = self.loss.multiview_render(y_src, x_src) + loss.render = self.render_weight*self.loss(y_src.image, x_src.image) + + # reconstruction + y_src.reco = self.loss.reconstruction(y_src, x_src.input_params) + + if self.use_target: + for k, w, loss_fn in self.weights: + loss[f'tgt_{k}'] = self.pl_reg_weight*loss_fn(y_tgt[k], x_tgt[k]) + + # rendering loss w/ pseudo label + y_tgt.image, x_tgt.image = self.loss.multiview_render(y_tgt, x_tgt) + loss.sd_render = self.pl_render_weight*self.loss(y_tgt.image, x_tgt.image) + + # reconstruction + y_tgt.reco = self.loss.reconstruction(y_tgt, x_tgt.input_params) + + loss.total = torch.stack(list(loss.values())).sum() + return loss + + @torch.no_grad() + def test(self, x, y, batch_idx, epoch, dl_id): + assert len(x.name) == 1 + y.reco = self.loss.reconstruction(y, x.input_params) + return EasyDict(total=0) + + @torch.no_grad() + def predict(self, x_tgt, y_tgt, batch_idx, split, epoch): + assert len(x_tgt.name) == 1 + + # gt components + I = x_tgt.input[0] + name = x_tgt.name[0] + + # get the predicted maps + N_pred = y_tgt.normals[0] + A_pred = y_tgt.albedo[0] + R_pred = y_tgt.roughness[0] + + # A_name = pl_path/f'{name}_albedo.png' + # save_image(A_pred, A_name) + + # N_name = pl_path/f'{name}_normals.png' + # save_image(encode_as_unit_interval(N_pred), N_name) + + # R_name = pl_path/f'{name}_roughness.png' + # save_image(R_pred, R_name) + + return EasyDict(total=0) diff --git a/capture/source/model.py b/capture/source/model.py new file mode 100644 index 0000000..7c028e5 --- /dev/null +++ b/capture/source/model.py @@ -0,0 +1,233 @@ +# Adapted from monodepth2 +# https://github.com/nianticlabs/monodepth2/blob/master/networks/depth_decoder.py +# +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. 
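+
+# The classes below implement the decomposition backbone used by `capture`: a ResNet encoder
+# (optionally accepting more than 3 input channels), a skip-connected convolutional decoder,
+# and a multi-head wrapper that runs one decoder per predicted map.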
+ +from __future__ import absolute_import, division, print_function +from collections import OrderedDict +from easydict import EasyDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torch.utils.model_zoo as model_zoo + + +class ConvBlock(torch.nn.Module): + """Layer to perform a convolution followed by ELU.""" + def __init__(self, in_channels, out_channels, bn=False, dropout=0.0): + super(ConvBlock, self).__init__() + + self.block = nn.Sequential( + Conv3x3(in_channels, out_channels), + nn.BatchNorm2d(out_channels) if bn else nn.Identity(), + nn.ELU(inplace=True), + # Pay attention: 2d version of dropout is used + nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()) + + def forward(self, x): + out = self.block(x) + return out + + +class Conv3x3(nn.Module): + """Layer to pad and convolve input with 3x3 kernels.""" + def __init__(self, in_channels, out_channels, use_refl=True): + super(Conv3x3, self).__init__() + + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out + +def upsample(x): + """Upsample input tensor by a factor of 2.""" + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class ResNetMultiImageInput(models.ResNet): + """Constructs a resnet model with varying number of input images. + Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + """ + def __init__(self, block, layers, num_classes=1000, in_channels=3): + super(ResNetMultiImageInput, self).__init__(block, layers) + self.inplanes = 64 + self.conv1 = nn.Conv2d( + in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def resnet_multiimage_input(num_layers, pretrained=False, in_channels=3): + """Constructs a ResNet model. + Args: + num_layers (int): Number of resnet layers. 
Must be 18 or 50 + pretrained (bool): If True, returns a model pre-trained on ImageNet + in_channels (int): Number of input channels + """ + assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" + blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] + block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] + model = ResNetMultiImageInput(block_type, blocks, in_channels=in_channels) + + if pretrained: + print('loading imagnet weights on resnet...') + loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) + # loaded['conv1.weight'] = torch.cat( + # (loaded['conv1.weight'], loaded['conv1.weight']), 1) + # diff = model.load_state_dict(loaded, strict=False) + return model + + +class ResnetEncoder(nn.Module): + """Pytorch module for a resnet encoder + """ + def __init__(self, num_layers, pretrained, in_channels=3): + super(ResnetEncoder, self).__init__() + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + + resnets = {18: models.resnet18, + 34: models.resnet34, + 50: models.resnet50, + 101: models.resnet101, + 152: models.resnet152} + + if num_layers not in resnets: + raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) + + if in_channels > 3: + self.encoder = resnet_multiimage_input(num_layers, pretrained, in_channels) + else: + weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None + self.encoder = resnets[num_layers](weights=weights) + + if num_layers > 34: + self.num_ch_enc[1:] *= 4 + + def forward(self, x): + self.features = [] + + # input_image, normals = xx + # x = (input_image - 0.45) / 0.225 + # x = torch.cat((input_image, normals),1) + x = self.encoder.conv1(x) + x = self.encoder.bn1(x) + self.features.append(self.encoder.relu(x)) + self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) + self.features.append(self.encoder.layer2(self.features[-1])) + self.features.append(self.encoder.layer3(self.features[-1])) + self.features.append(self.encoder.layer4(self.features[-1])) + + return self.features + + +class Decoder(nn.Module): + def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, + kaiming_init=False, return_feats=False): + super().__init__() + + self.num_output_channels = num_output_channels + self.use_skips = use_skips + self.upsample_mode = 'nearest' + self.scales = scales + + self.return_feats = return_feats + + self.num_ch_enc = num_ch_enc + self.num_ch_dec = np.array([16, 32, 64, 128, 256]) + + # decoder + self.convs = OrderedDict() + for i in range(4, -1, -1): + # upconv_0 + num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) + + # upconv_1 + num_ch_in = self.num_ch_dec[i] + if self.use_skips and i > 0: + num_ch_in += self.num_ch_enc[i - 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) + + # for s in self.scales: + self.convs[("dispconv", 0)] = Conv3x3(self.num_ch_dec[0], self.num_output_channels) + + self.decoder = nn.ModuleList(list(self.convs.values())) + # self.sigmoid = nn.Sigmoid() + + if kaiming_init: + print('init weights of decoder') + for m in self.children(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + m.bias.data.fill_(0.01) + + def forward(self, input_features): + x = input_features[-1] + for i in range(4, -1, -1): + x = self.convs[("upconv", i, 0)](x) 
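+            # upsample by 2, append the encoder skip feature at this scale (if enabled), then fuse with the second conv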
+ x = [upsample(x)] + if self.use_skips and i > 0: + x += [input_features[i - 1]] + x = torch.cat(x, 1) + x = self.convs[("upconv", i, 1)](x) + + # assert self.scales[0] == 0 + final_conv = self.convs[("dispconv", 0)] + out = final_conv(x) + + if self.return_feats: + return out, input_features[-1] + return out + +class MultiHeadDecoder(nn.Module): + def __init__(self, num_ch_enc, tasks, return_feats, use_skips): + super().__init__() + self.decoders = nn.ModuleDict({k: + Decoder(num_ch_enc=num_ch_enc, + num_output_channels=num_ch, + scales=[0], + kaiming_init=False, + use_skips=use_skips, + return_feats=return_feats) + for k, num_ch in tasks.items()}) + + def forward(self, x): + y = EasyDict({k: v(x) for k, v in self.decoders.items()}) + return y + +class DenseMTL(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder = encoder + self.decoder = decoder + def forward(self, x): + return self.decoder(self.encoder(x)) \ No newline at end of file diff --git a/capture/source/routine.py b/capture/source/routine.py new file mode 100644 index 0000000..1969c53 --- /dev/null +++ b/capture/source/routine.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from torch import optim +from pathlib import Path +from easydict import EasyDict +import pytorch_lightning as pl +import torch.nn.functional as F +import torchvision.transforms as T +from torchvision.utils import save_image +from torchmetrics import MeanSquaredError, StructuralSimilarityIndexMeasure + +from . import DenseReg, RenderingLoss +from ..render import Renderer, encode_as_unit_interval, gamma_decode, gamma_encode + + +class Vanilla(pl.LightningModule): + metrics = ['I_mse','N_mse','A_mse','R_mse','I_ssim','N_ssim','A_ssim','R_ssim'] + maps = {'I': 'reco', 'N': 'normals', 'R': 'roughness', 'A': 'albedo'} + + def __init__(self, model: nn.Module, loss: DenseReg = None, lr: float = 0, batch_size: int = 0): + super().__init__() + self.model = model + self.loss = loss + self.lr = lr + self.batch_size = batch_size + self.tanh = nn.Tanh() + self.norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def training_step(self, x): + y = self(*x) + loss = self.loss(x, y) + self.log_to('train', loss) + return dict(loss=loss.total, y=y) + + def forward(self, src, tgt): + src_out, tgt_out = None, None + + if None not in src: + src_out = self.model(self.norm(src.input)) + self.post_process_(src_out) + + if None not in tgt: + tgt_out = self.model(self.norm(tgt.input)) + self.post_process_(tgt_out) + + return src_out, tgt_out + + def post_process_(self, o: EasyDict): + # (1) activation function, (2) concat unit z, (3) normalize to unit vector + nxy = self.tanh(o.normals) + nx, ny = torch.split(nxy*3, split_size_or_sections=1, dim=1) + n = torch.cat([nx, ny, torch.ones_like(nx)], dim=1) + o.normals = F.normalize(n, dim=1) + + # (1) activation function, (2) mapping [-1,1]->[0,1] + a = self.tanh(o.albedo) + o.albedo = encode_as_unit_interval(a) + + # (1) activation function, (2) mapping [-1,1]->[0,1], (3) channel repeat x3 + r = self.tanh(o.roughness) + o.roughness = encode_as_unit_interval(r.repeat(1,3,1,1)) + + def validation_step(self, x, *_): + y = self(*x) + loss = self.loss(x, y) + self.log_to('val', loss) + return dict(loss=loss.total, y=y) + + def log_to(self, split, loss): + self.log_dict({f'{split}/{k}': v for k, v in loss.items()}, batch_size=self.batch_size) + + def on_test_start(self): + self.renderer = RenderingLoss(Renderer()) + + for m in Vanilla.metrics: + if 'mse' in m: + 
setattr(self, m, MeanSquaredError().to(self.device)) + elif 'ssim' in m: + setattr(self, m, StructuralSimilarityIndexMeasure(data_range=1).to(self.device)) + + def test_step(self, x, batch_idx, dl_id=0): + y = self.model(self.norm(x.input)) + self.post_process_(y) + + # image reconstruction + y.reco = self.renderer.reconstruction(y, x.input_params) + x.reco = gamma_decode(x.input) + + for m in Vanilla.metrics: + mapid, *_ = m + k = Vanilla.maps[mapid] + meter = getattr(self, m) + meter(y[k], x[k].to(y[k].dtype)) + self.log(m, getattr(self, m), on_epoch=True) + + def predict_step(self, x, batch_idx): + y = self.model(self.norm(x.input)) + self.post_process_(y) + + I, name, outdir = x.input[0], x.name[0], Path(x.path[0]).parent + N_pred, A_pred, R_pred = y.normals[0], y.albedo[0], y.roughness[0] + + save_image(gamma_encode(A_pred), outdir/f'{name}_albedo.png') + save_image(encode_as_unit_interval(N_pred), outdir/f'{name}_normals.png') + save_image(R_pred, outdir/f'{name}_roughness.png') + + def configure_optimizers(self): + optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4) + return dict(optimizer=optimizer) \ No newline at end of file diff --git a/capture/utils/__init__.py b/capture/utils/__init__.py new file mode 100644 index 0000000..9634c87 --- /dev/null +++ b/capture/utils/__init__.py @@ -0,0 +1,8 @@ + +from .model import get_module +from .cli import get_args +from .exp import Trainer, get_name, get_callbacks, get_data +from .log import get_logger + + +__all__ = ['get_model', 'get_module', 'get_args', 'get_name', 'get_logger', 'get_data', 'get_callbacks', 'Trainer'] \ No newline at end of file diff --git a/capture/utils/cli.py b/capture/utils/cli.py new file mode 100644 index 0000000..e58c063 --- /dev/null +++ b/capture/utils/cli.py @@ -0,0 +1,76 @@ + +from pathlib import Path + +import jsonargparse +import torch.nn as nn +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger + +from ..source import Vanilla, DenseReg +from ..callbacks import VisualizeCallback +from ..data.module import DataModule + + +#! refactor this simplification required +class LightningArgumentParser(jsonargparse.ArgumentParser): + """ + Extension of jsonargparse.ArgumentParser to parse pl.classes and more. 
+ """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def add_datamodule(self, datamodule_obj: pl.LightningDataModule): + self.add_method_arguments(datamodule_obj, '__init__', 'data', as_group=True) + + def add_lossmodule(self, lossmodule_obj: nn.Module): + self.add_class(lossmodule_obj, 'loss') + + def add_routine(self, model_obj: pl.LightningModule): + skip = {'ae', 'decoder', 'loss', 'transnet', 'model', 'discr', 'adv_loss', 'stage'} + self.add_class_arguments(model_obj, 'routine', as_group=True, skip=skip) + + def add_logger(self, logger_obj): + skip = {'version', 'config', 'name', 'save_dir'} + self.add_class_arguments(logger_obj, 'logger', as_group=True, skip=skip) + + def add_class(self, cls, group, **kwargs): + self.add_class_arguments(cls, group, as_group=True, **kwargs) + + def add_trainer(self): + skip = {'default_root_dir', 'logger', 'callbacks'} + self.add_class_arguments(pl.Trainer, 'trainer', as_group=True, skip=skip) + +def get_args(datamodule=DataModule, loss=DenseReg, routine=Vanilla, viz=VisualizeCallback): + parser = LightningArgumentParser() + + parser.add_argument('--config', action=jsonargparse.ActionConfigFile, required=True) + parser.add_argument('--archi', type=str, required=True) + parser.add_argument('--out_dir', type=lambda x: Path(x), required=True) + + parser.add_argument('--seed', default=666, type=int) + parser.add_argument('--load_weights_from', type=lambda x: Path(x)) + parser.add_argument('--save_ckpt_every', default=10, type=int) + parser.add_argument('--wandb', action='store_true', default=False) + parser.add_argument('--mode', choices=['train', 'eval', 'test', 'predict'], default='train', type=str) + parser.add_argument('--resume_from', default=None, type=str) + + if datamodule is not None: + parser.add_datamodule(datamodule) + + if loss is not None: + parser.add_lossmodule(loss) + + if routine is not None: + parser.add_routine(routine) + + if viz is not None: + parser.add_class_arguments(viz, 'viz', skip={'out_dir', 'exist_ok'}) + + # bindings between modules (data/routine/loss) + parser.link_arguments('data.batch_size', 'routine.batch_size') + + parser.add_logger(WandbLogger) + parser.add_trainer() + + args = parser.parse_args() + return args \ No newline at end of file diff --git a/capture/utils/exp.py b/capture/utils/exp.py new file mode 100644 index 0000000..5c61b9d --- /dev/null +++ b/capture/utils/exp.py @@ -0,0 +1,77 @@ +import os + +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning import Trainer as plTrainer +from pathlib import Path + +from ..callbacks import VisualizeCallback, MetricLogging +from ..data.module import DataModule +from .log import get_info + + +def get_data(args=None, **kwargs): + if args is None: + return DataModule(**kwargs) + else: + return DataModule(**args.data) + +def get_name(args) -> str: + name = ''#f'{args.mode}' + + src_ds_verbose = str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper() + + if args.mode == 'train': + if args.loss.use_source and not args.loss.use_target: + name += f'pretrain_ds{args.data.source_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_bs{args.data.batch_size}_reg{args.loss.reg_weight}_rend{args.loss.render_weight}_ds{str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper()}' + + elif args.loss.use_target: + name += 
f'_F_{args.data.target_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_bs{args.data.tgt_bs}_aug{int(args.data.transform)}_reg{args.loss.pl_reg_weight}_rend{args.loss.pl_render_weight}_ds{str(args.data.target_list).split(os.sep)[-1].replace("_","-").upper()}' + else: + name += f'_T_{args.data.source_ds.upper()}_lr{args.routine.lr}_x{args.data.input_size}_aug{int(args.data.transform)}_reg{args.loss.render_weight}_rend{args.loss.render_weight}_ds{str(args.data.source_list).split(os.sep)[-1].replace("_","-").upper()}' + + if args.loss.adv_weight: + name += f'_ADV{args.loss.adv_weight}' + if args.data.source_ds == 'acg': + name += f'_mixbs{args.data.batch_size}' + if args.loss.reg_weight != 0.1: + name += f'_regSRC{args.loss.reg_weight}' + if args.loss.render_weight != 1: + name += f'_rendSRC{args.loss.render_weight}' + if args.data.use_ref: + name += '_useRef' + if args.load_weights_from: + wname, epoch = get_info(str(args.load_weights_from)) + assert wname and epoch + name += f'_init{wname.replace("_", "-")}-{epoch}ep' + + name += f'_s{args.seed}' + return name + # name += args.load_weights_from.split(os.sep)[-1][:-5] + +def get_callbacks(args): + callbacks = [ + VisualizeCallback(out_dir=args.out_dir, exist_ok=bool(args.resume_from), **args.viz), + ModelCheckpoint( + dirpath=args.out_dir/'ckpt', + filename='{name}_{epoch}-{step}', + save_weights_only=False, + save_top_k=-1, + every_n_epochs=args.save_ckpt_every), + MetricLogging(args.load_weights_from, args.data.test_list, outdir=Path('./logs')), + ] + return callbacks + +class Trainer(plTrainer): + def __init__(self, o_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ckpt_path = o_args.resume_from + + def __call__(self, mode, module, data) -> None: + if mode == 'test': + self.test(module, data) + elif mode == 'eval': + self.validate(module, data) + elif mode == 'predict': + self.predict(module, data) + elif mode == 'train': + self.fit(module, data, ckpt_path=self.ckpt_path) \ No newline at end of file diff --git a/capture/utils/log.py b/capture/utils/log.py new file mode 100644 index 0000000..31af842 --- /dev/null +++ b/capture/utils/log.py @@ -0,0 +1,47 @@ +import csv, re, os + +from pytorch_lightning.loggers import TensorBoardLogger + + +def read_csv(fname): + with open(fname, 'r') as f: + reader = csv.DictReader(f) + return list(reader) + +def append_csv(fname, dicts): + if isinstance(dicts, dict): + dicts = [dicts] + + if os.path.isfile(fname): + dicts = read_csv(fname) + dicts + + write_csv(fname, dicts) + +def write_csv(fname, dicts): + assert len(dicts) > 0 + with open(fname, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=dicts[0].keys()) + writer.writeheader() + for d in dicts: + writer.writerow(d) + +def now(): + from datetime import datetime + return datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + +def get_info(weights: str): + search = re.search(r"(.*)_epoch=(\d+)-step", weights) + if search: + name, epoch = search.groups() + return str(name).split(os.sep)[-1], str(epoch) + return None, None + +def get_matlist(cache_dir, dir): + with open(cache_dir, 'r') as f: + content = f.readlines() + files = [dir/f.strip() for f in content] + return files + +def get_logger(args): + logger = TensorBoardLogger(save_dir=args.out_dir) + logger.log_hyperparams(args) diff --git a/capture/utils/model.py b/capture/utils/model.py new file mode 100644 index 0000000..9896887 --- /dev/null +++ b/capture/utils/model.py @@ -0,0 +1,43 @@ +import torch.nn as nn + +from pathlib import Path + +from ..source import 
ResnetEncoder, MultiHeadDecoder, DenseMTL, DenseReg, Vanilla + + +def replace_batchnorm_(module: nn.Module): + for name, child in module.named_children(): + if isinstance(child, nn.BatchNorm2d): + setattr(module, name, nn.InstanceNorm2d(child.num_features)) + else: + replace_batchnorm_(child) + +def get_model(archi): + assert archi == 'densemtl' + + encoder = ResnetEncoder(num_layers=101, pretrained=True, in_channels=3) + decoder = MultiHeadDecoder( + num_ch_enc=encoder.num_ch_enc, + tasks=dict(albedo=3, roughness=1, normals=2), + return_feats=False, + use_skips=True) + + model = nn.Sequential(encoder, decoder) + replace_batchnorm_(model) + return model + +def get_module(args): + loss = DenseReg(**args.loss) + model = get_model(args.archi) + + weights = args.load_weights_from + if weights: + assert weights.is_file() + return Vanilla.load_from_checkpoint(str(weights), model=model, loss=loss, strict=False, **args.routine) + + return Vanilla(model, loss, **args.routine) + +def get_inference_module(pt): + assert Path(pt).exists() + model = get_model('densemtl') + return Vanilla.load_from_checkpoint(str(pt), model=model, strict=False) \ No newline at end of file diff --git a/concept/__init__.py b/concept/__init__.py new file mode 100644 index 0000000..e541c00 --- /dev/null +++ b/concept/__init__.py @@ -0,0 +1,6 @@ +from .crop import main as crop +from .invert import invert +from .infer import infer +from .renorm import renorm + +__all__ = ['crop', 'invert', 'infer', 'renorm'] \ No newline at end of file diff --git a/concept/args.py b/concept/args.py new file mode 100644 index 0000000..352a616 --- /dev/null +++ b/concept/args.py @@ -0,0 +1,202 @@ +import os +from argparse import ArgumentParser + + +def get_argparse_defaults(parser): + # https://stackoverflow.com/questions/44542605/python-how-to-get-all-default-values-from-argparse + defaults = {} + for action in parser._actions: + if not action.required and action.dest != "help": + defaults[action.dest] = action.default + return defaults + +def parse_args(return_defaults=False): + parser = ArgumentParser() + + parser.add_argument('--pretrained_model_name_or_path', type=str, default='runwayml/stable-diffusion-v1-5', + help='Path to pretrained model or model identifier from huggingface.co/models.') + + parser.add_argument('--revision', type=str, default=None, required=False, + help='Revision of pretrained model identifier from huggingface.co/models.') + + parser.add_argument('--seed', type=int, default=1, + help='A seed for reproducible training.') + + parser.add_argument('--local_rank', type=int, default=-1, + help='For distributed training: local_rank') + + + ## Dataset + parser.add_argument('--path', type=str, required=True, + help='A folder containing the training data of instance images.') + + parser.add_argument('--prompt', type=str, default='an object with azertyuiop texture', + help='The prompt with identifier specifying the instance') + + parser.add_argument('--tokenizer_name', type=str, default=None, + help='Pretrained tokenizer name or path if not the same as model_name') + + parser.add_argument('--resolution', type=int, default=256, # 512 + help='Resolution of train/validation images') + + + ## LoRA options + parser.add_argument('--use_lora', action='store_true', default=True, # overwrite + help='Whether to use Lora for parameter efficient tuning') + + parser.add_argument('--lora_r', type=int, default=16, # 8 + help='Lora rank, only used if use_lora is True') + + parser.add_argument('--lora_alpha', type=int, default=27, # 32 + 
help='Lora alpha, only used if use_lora is True') + + parser.add_argument('--lora_dropout', type=float, default=0.0, + help='Lora dropout, only used if use_lora is True') + + parser.add_argument('--lora_bias', type=str, default='none', + help='Bias type for Lora: ["none", "all", "lora_only"], only used if use_lora is True') + + parser.add_argument('--lora_text_encoder_r', type=int, default=16, # 8 + help='Lora rank for text encoder, only used if `use_lora` & `train_text_encoder` are True') + + parser.add_argument('--lora_text_encoder_alpha', type=int, default=17, # 32 + help='Lora alpha for text encoder, only used if `use_lora` & `train_text_encoder` are True') + + parser.add_argument('--lora_text_encoder_dropout', type=float, default=0.0, + help='Lora dropout for text encoder, only used if `use_lora` & `train_text_encoder` are True') + + parser.add_argument('--lora_text_encoder_bias', type=str, default='none', + help='Bias type for Lora: ["none", "all", "lora_only"] when `use_lora` & `train_text_encoder` are True') + + + ## Training hyperparameters + parser.add_argument('--train_text_encoder', action='store_true', + help='Whether to train the text encoder') + + parser.add_argument('--train_batch_size', type=int, default=1, + help='Batch size (per device) for the training dataloader.') + + # parser.add_argument('--num_train_epochs', type=int, default=1, + # help="Number of training epochs, used when `max_train_steps` is not set.") + + parser.add_argument('--max_train_steps', type=int, default=800, + help='Total number of training steps to perform.') + + # parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + # help='Number of updates steps to accumulate before performing a backward/update pass.') + + parser.add_argument('--gradient_checkpointing', action='store_true', + help='Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.') + + parser.add_argument('--scale_lr', action='store_true', default=False, + help='Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.') + + parser.add_argument('--lr_scheduler', type=str, default='constant', + choices=['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup'], + help='The scheduler type to use.') + + parser.add_argument('--lr_warmup_steps', type=int, default=0, # 500 + help='Number of steps for the warmup in the lr scheduler.') + + parser.add_argument('--lr_num_cycles', type=int, default=1, + help='Number of hard resets of the lr in cosine_with_restarts scheduler.') + + parser.add_argument('--lr_power', type=float, default=1.0, + help='Power factor of the polynomial scheduler.') + + parser.add_argument('--adam_beta1', type=float, default=0.9, + help='The beta1 parameter for the Adam optimizer.') + + parser.add_argument('--adam_beta2', type=float, default=0.999, + help='The beta2 parameter for the Adam optimizer.') + + parser.add_argument('--adam_weight_decay', type=float, default=1e-2, + help='Weight decay to use.') + + parser.add_argument('--adam_epsilon', type=float, default=1e-08, + help='Epsilon value for the Adam optimizer') + + parser.add_argument('--max_grad_norm', default=1.0, type=float, + help='Max gradient norm.') + + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Initial learning rate (after the potential warmup period) to use.') + + + ## Prior preservation loss + # parser.add_argument('--with_prior_preservation', default=False, action='store_true', + # help='Flag to 
add prior preservation loss.') + + # parser.add_argument('--prior_loss_weight', type=float, default=1.0, + # help='The weight of prior preservation loss.') + + # parser.add_argument('--class_data_dir', type=str, default=None, required=False, + # help='A folder containing the training data of class images.') + + # parser.add_argument('--class_prompt', type=str, default=None, + # help='The prompt to specify images in the same class as provided instance images.') + + # parser.add_argument('--num_class_images', type=int, default=100, + # help='Min number for prior preservation loss, if lower, more images will be sampled with `class_prompt`.') + + # parser.add_argument('--prior_generation_precision', type=str, default=None, + # choices=['no', 'fp32', 'fp16', 'bf16'], + # help='Precision type for prior generation (bf16 requires PyTorch>= 1.10 + Nvidia Ampere GPU)') + + ## Logs + parser.add_argument('--checkpointing_steps', type=int, default=800, + help='Save a checkpoint every X steps, can be used to resume training w/ `--resume_from_checkpoint`.') + + parser.add_argument('--resume_from_checkpoint', type=str, default=None, + help='Resume from checkpoint obtained w/ `--checkpointing_steps`, or `"latest"`.') + + parser.add_argument('--validation_prompt', type=str, default=None, + help='A prompt that is used during validation to verify that the model is learning.') + + parser.add_argument('--num_validation_images', type=int, default=4, + help='Number of images that should be generated during validation with `validation_prompt`.') + + parser.add_argument('--validation_steps', type=int, default=100, + help='Run validation every X steps: runs w/ prompt `args.validation_prompt` `args.num_validation_images` times.') + + # parser.add_argument('--output_dir', type=Path, default=None, + # help='The output directory where the model predictions and checkpoints will be written.') + + parser.add_argument('--logging_dir', type=str, default='logs', + help='TensorBoard log directory, defaults default to `output_dir`/runs/**CURRENT_DATETIME_HOSTNAME***.') + + parser.add_argument('--report_to', type=str, default='tensorboard', + choices=['tensorboard', 'wandb', 'comet_ml', 'all'], + help='The integration to report the results and logs to.') + + parser.add_argument('--wandb_key', type=str, default=None, + help='If report to option is set to wandb, api-key for wandb used for login to wandb.') + + parser.add_argument('--wandb_project_name', type=str, default=None, + help='If report to option is set to wandb, project name in wandb for log tracking.') + + + ## Advanced options + parser.add_argument('--use_8bit_adam', action='store_true', + help='Whether or not to use 8-bit Adam from bitsandbytes.') + + parser.add_argument('--allow_tf32', action='store_true', + help='Whether or not to allow TF32 on Ampere GPUs. 
Can be used to speed up training.') + + parser.add_argument('--mixed_precision', type=str, default='fp16', + choices=['no', 'fp16', 'bf16'], + help='Precision type (bf16 requires PyTorch>= 1.10 + Nvidia Ampere GPU)') + + parser.add_argument('--enable_xformers_memory_efficient_attention', action='store_true', + help='Whether or not to use xformers.') + + if return_defaults: + return get_argparse_defaults(parser) + + args = parser.parse_args() + + env_local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args \ No newline at end of file diff --git a/concept/crop.py b/concept/crop.py new file mode 100644 index 0000000..b50a18d --- /dev/null +++ b/concept/crop.py @@ -0,0 +1,98 @@ +from PIL import Image +from random import shuffle + +import torchvision.transforms.functional as tf + + +def main(path, patch_sizes=[512, 256, 192, 128, 64], threshold=.99, topk=100): + assert path.is_dir(), \ + f'the provided path {path} is not a directory or does not exist' + + masks_dir = path/'masks' + assert masks_dir.is_dir(), \ + f'a /masks subdirectory containing the image masks should be present in {path}' + + files = [x for x in path.iterdir() if x.is_file()] + assert len(files) == 1, \ + f'the target path {path} should contain a single image file!' + + img_path = files[0] + print(f'---- processing image "{img_path.name}"') + + out_dir = img_path.parent/'crops' + out_dir.mkdir(parents=True, exist_ok=True) + + pil_ref = Image.open(img_path).convert('RGB') + img_shape = (pil_ref.width, pil_ref.height) + ref = tf.to_tensor(pil_ref) + + k = 0 + masks = sorted(x for x in masks_dir.iterdir()) + print(f' found {len(masks)} masks...') + + for i, f in enumerate(masks): + clusterdir = out_dir/f.stem + + if clusterdir.exists(): + k += 1 + continue + + pil_mask = Image.open(f).convert('RGB').resize(img_shape) + + main_bbox = pil_mask.convert('L').point(lambda x: 0 if x == 0 else 1, '1').getbbox() + x0, y0, *_ = main_bbox + + cropped_mask = tf.to_tensor(pil_mask.crop(main_bbox)) > 0 + + mask_d = int(cropped_mask[0].float().sum()) + print(f' > "{f.stem}" cluster, q={cropped_mask[0].float().mean():.2%}') + + kept_bboxes = [] + kept_scales = [] + for patch_size in patch_sizes: + stride = patch_size//5 + densities, bboxes = patch_image(cropped_mask, patch_size, stride, x0, y0) + + kept_local_res = [] + for d, b in zip(densities, bboxes): + if d >= threshold: + kept_local_res.append(b) + + shuffle(kept_local_res) + nb_kept = topk - len(kept_bboxes) + kept_local_res = kept_local_res[:nb_kept] + + kept_bboxes += kept_local_res + kept_scales += [patch_size]*len(kept_local_res) + + print(f' {patch_size}x{patch_size} kept {len(kept_local_res)} patches -> {clusterdir}') + + if len(kept_local_res) > 0: # only take largest scale + break + + if len(kept_bboxes) < 2: + print(f' skipping, only found {len(kept_bboxes)} patches.') + continue + + clusterdir.mkdir(exist_ok=True) + for i, (s, b) in enumerate(zip(kept_scales, kept_bboxes)): + cname = clusterdir/f'{i:0>5}_x{s}.png' + pil_ref.crop(b).save(cname) + + k += 1 + + print(f'---- kept {k}/{len(masks)} crops.') + + return out_dir + +def patch_image(mask, psize, stride, x0, y0): + densities, bboxes = [], [] + height, width = mask.shape[-2:] + for j in range(0, height - psize + 1, stride): + for i in range(0, width - psize + 1, stride): + patch = mask[0, j:j+psize, i:i+psize] + density = patch.float().mean().item() + densities.append(density) + bbox = x0+i, y0+j, x0+i+psize, 
y0+j+psize + bboxes.append(bbox) + return densities, bboxes \ No newline at end of file diff --git a/concept/data.py b/concept/data.py new file mode 100644 index 0000000..28188c3 --- /dev/null +++ b/concept/data.py @@ -0,0 +1,58 @@ +from pathlib import Path +import math +from PIL import Image + +import torch +import torchvision.transforms.functional as tf +import torchvision.transforms as T +from torchvision.transforms import RandomRotation +import torch.utils.checkpoint +from torch.utils.data import Dataset + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__(self, data_dir, prompt, tokenizer, size=512): + super().__init__() + self.size = size + self.tokenizer = tokenizer + + self.data_dir = Path(data_dir) + if not self.data_dir.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = [x for x in data_dir.iterdir() if x.is_file()] + assert len(self) > 0, 'data directory is empty' + + self.prompt = prompt + + self.image_transforms = T.Compose([ + T.RandomHorizontalFlip(), + T.RandomVerticalFlip(), + T.Resize(size, interpolation=T.InterpolationMode.BILINEAR), + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return len(self.instance_images_path) + + def __getitem__(self, index): + image = Image.open(self.instance_images_path[index % len(self)]) + if not image.mode == "RGB": + image = image.convert("RGB") + + img = self.image_transforms(image) + prompt = self.tokenizer( + self.prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + return img, prompt \ No newline at end of file diff --git a/concept/infer.py b/concept/infer.py new file mode 100644 index 0000000..68484a6 --- /dev/null +++ b/concept/infer.py @@ -0,0 +1,412 @@ +import os +import argparse +from pathlib import Path +import random +from itertools import product +from argparse import Namespace + +import torch +import numpy as np +from tqdm import tqdm +import torch.nn.functional as F +import torchvision.transforms.functional as tf +from torchvision.utils import save_image +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg + + +def parse_args(): + parser = argparse.ArgumentParser(description="Inference code for generating samples from concept.") + parser.add_argument('path', type=Path, default=None) + parser.add_argument('--outdir', type=Path, default=None) + parser.add_argument('--token', type=str, default=None) + parser.add_argument('--stitch_mode', type=str, default='wmean', choices=['concat', 'mean', 'wmean']) + parser.add_argument('--resolution', default=1024, choices=[512, 1024, 2048, 4096, 8192], type=int) + parser.add_argument('--prompt', type=str, default='p1', choices=['p1', 'p2', 'p3', 'p4']) + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--renorm', action="store_true", default=False) + parser.add_argument('--num_inference_steps', type=int, default=50) + args = parser.parse_args() + return args + +def get_roll(x): + h, w = x.size(-2), x.size(-1) + dh, dw = random.randint(0,h), random.randint(0,w) + return dh, dw + +def patch(x, k): + n, c, h, w = x.shape + x_ = x.view(-1,k*k,c*h*w).transpose(1,-1) # (n, c*h*w, k*k) + folded = F.fold(x_, output_size=(h*k,w*k), kernel_size=(h,w), stride=(h,w)) # (n, c, h*k, w*k) + return folded + +def 
unpatch(x, k, p=0): + n, c, kh, kw = x.shape + h, w = (kh-2*p)//k, (kw-2*p)//k + x_ = F.unfold(x, kernel_size=(h+2*p,w+2*p), stride=(h,w)) # (n, c*[h+2p]*[w+2p], k*k) + unfolded = x_.transpose(1,2).reshape(-1,c,64+2*p,64+2*p) # (n*k*k, c, h+2p, w+2p) + return unfolded + +def get_kernel(p, device): + x1, x2 = 512-1, 512+2*p-1 + y1, y2 = 1, 0 + fun = lambda x: (y1-y2)/(x1-x2)*x + (x1*y2-x2*y1)/(x1-x2) + x = torch.arange(512+2*p, device=device) + y = fun(x) + y[:512]=1 + y += y.flip(0) + y -= 1 + Y = torch.outer(y,y) + return Y[None][None] + +def get_lora_sd_pipeline( + ckpt_dir, base_model_name_or_path=None, dtype=torch.float16, device="cuda", adapter_name="default" +): + from peft import PeftModel, LoraConfig + from diffusers import StableDiffusionPipeline + + unet_sub_dir = os.path.join(ckpt_dir, "unet") + text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder") + + if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None: + config = LoraConfig.from_pretrained(text_encoder_sub_dir) + base_model_name_or_path = config.base_model_name_or_path + + if base_model_name_or_path is None: + raise ValueError("Please specify the base model name or path") + + pipe = StableDiffusionPipeline.from_pretrained( + base_model_name_or_path, + torch_dtype=dtype, + local_files_only=True, + safety_checker=None, + ).to(device) + + pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name) + + if os.path.exists(text_encoder_sub_dir): + pipe.text_encoder = PeftModel.from_pretrained( + pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name + ) + + if dtype in (torch.float16, torch.bfloat16): + pipe.unet.half() + pipe.text_encoder.half() + + pipe.to(device) + return pipe + +def get_vanilla_sd_pipeline(device='cuda'): + from diffusers import StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, + local_files_only=True, + safety_checker=None, + ) + pipe.to(device) + return pipe + +@torch.no_grad() +def main(args): + if args.token is None: + assert args.path.is_dir() + # global_step =j f"{args.path.name.split('-')[-1]:0>4}" + + token = 'azertyuiop' + print(f'loading LoRA with token {token}') + pipe = get_lora_sd_pipeline(ckpt_dir=Path(args.path)) + else: + token = args.token + pipe = get_vanilla_sd_pipeline() + print(f'picked token={token}') + + v_token = token + + prompt = dict( + p1='top view realistic texture of {}', + p2='top view realistic {} texture', + p3='high resolution realistic {} texture in top view', + p4='realistic {} texture in top view', + )[args.prompt] + print(f'{args.prompt} => {prompt}') + v_prompt = prompt.replace(' ', '-').format('o') + prompt = prompt.format(token) + + # negative_prompt = "lowres, error, cropped, worst quality, low quality, jpeg artifacts, out of frame, watermark, signature, illustration, painting, drawing, art, sketch" + negative_prompt = "" + generator = torch.Generator("cuda").manual_seed(args.seed) + random.seed(args.seed) + + if args.path is not None: + outdir = args.path/'outputs' + print(f'ignoring `args.outdir` and using path {outdir}') + outdir.mkdir(exist_ok=True) + else: + # ckpt_dir + outdir = args.outdir + + reso = {512: 'hK', 1024: '1K', 2048: '2K', 4096: '4K', 8192: '8K'}[args.resolution] + fname = outdir/f'{v_token}_{reso}_t{args.num_inference_steps}_{args.stitch_mode}_{v_prompt}_{args.seed}.png' + + if fname.exists(): + print('already exists!') + return fname + print(f'preparing for {fname}') + 
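+ # Note on the tiled inference implemented below: k = resolution // 512 is the number of 512px tiles per side, + # so the k*k latent tiles (64x64 each) together form one large texture. At every denoising step the tiles are + # folded into a single (64k x 64k) latent canvas (`patch`), circularly rolled by a random offset (`get_roll`) + # so that tile seams land somewhere different at each step, split back into tiles (`unpatch`) for the per-tile + # UNet passes, and the predicted noise is rolled back. The circular roll wraps content around the canvas + # borders, which is what makes the final texture tileable. + # E.g. --resolution 2048 gives k=4, i.e. 16 latent tiles decoded to a 2048x2048 image.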
+ ################################################################################################ + # Inference code + ################################################################################################ + k= (args.resolution//512) + + num_images_per_prompt=1 + guidance_scale=7.5 + # guidance_scale=1.0 + + callback_steps=1 + cross_attention_kwargs=None + # clip_skip=None + num_inference_steps=args.num_inference_steps + eta=0.0 + guidance_rescale=0.0 + callback=None + callback_steps=1 + output_type='pil' + height=None + width=None + latents=None + prompt_embeds=None + negative_prompt_embeds=None + + height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor + width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + pipe.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = pipe._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = pipe._encode_prompt( + prompt, + device, + num_images_per_prompt*k*k, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + pipe.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = pipe.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = pipe.unet.config.in_channels + latents = pipe.prepare_latents( + (batch_size * num_images_per_prompt)*k*k, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order + with pipe.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) + + # roll noise + kx, ky = get_roll(latent_model_input) + latent_model_input = patch(latent_model_input, k) + latent_model_input = latent_model_input.roll((kx, ky), dims=(2,3)) + latent_model_input = unpatch(latent_model_input, k) + + # split in two for inference + noise_pred = [] + chunk_size = len(latent_model_input)//16 or 1 + for latent_chunk, prompt_chunk \ + in zip(latent_model_input.chunk(chunk_size), prompt_embeds.chunk(chunk_size)): + # predict the noise residual + res = pipe.unet(latent_chunk, t, encoder_hidden_states=prompt_chunk) + noise_pred.append(res.sample) + noise_pred = torch.cat(noise_pred) + + # noise unrolling + noise_pred = patch(noise_pred, k) + noise_pred = noise_pred.roll((-kx, -ky), dims=(2,3)) + noise_pred = unpatch(noise_pred, k) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0): + progress_bar.update() + + + if args.resolution == 512: + decoded = pipe.vae.decode(patch(latents, k) / pipe.vae.config.scaling_factor) + decoded = decoded.sample.detach().cpu().double() + images = pipe.image_processor.postprocess(decoded, output_type='pil', do_denormalize=[True]*len(decoded)) + images[0].save(fname) + + ## stiching part + if args.stitch_mode == 'concat': # naive concatenation + # image = pipe.vae.decode(folded / pipe.vae.config.scaling_factor) + chunk_size = len(latents)//16 or 1 + out = [] + for chunk in latents.chunk(chunk_size): + image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor) + out.append(image.sample.detach().cpu().double()) + out = torch.cat(out) + + images = pipe.image_processor.postprocess(out, output_type='pt', do_denormalize=[True]*len(out)) + save_image(images, fname, nrow=k, padding=0) + # [img.save(f'{i}.png') for i, img in enumerate(images)] + + elif args.stitch_mode == 'mean': # patch mean blending + p=1 + folded = patch(latents, k) + folded_padded = F.pad(folded, pad=(p,p,p,p), mode='circular') + unfolded_padded = unpatch(folded_padded, k, p) + + chunk_size = len(unfolded_padded)//16 or 1 + image_stack = [] + for chunk in unfolded_padded.chunk(chunk_size): + image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor) + image_stack.append(image.sample) + image_stack = torch.cat(image_stack) + + lmean = image_stack.mean(dim=(-1,-2), keepdim=True) + gmean = image_stack.mean(dim=(0,2,3), keepdim=True) + image_stack = image_stack*gmean/lmean + + # with a naive average stitching, the overlap values (bands) are divided + s = pipe.vae_scale_factor # 1:8 in 
pixel space + tp = 2*s*p # total padding + image_stack[:,:,:tp,:] /= 2. + image_stack[:,:,-tp:,:] /= 2. + image_stack[:,:,:,:tp] /= 2. + image_stack[:,:,:,-tp:] /= 2. + + # gather values into final tensor + _, c, hpad, wpad = image_stack.shape + h, w = hpad-tp, wpad-tp + out_padded = torch.zeros(batch_size, c, h*k+tp, w*k+tp, device=image_stack.device) + for i, j in product(range(k), range(k)): + out_padded[:,:,h*i:w*(i+1)+tp,h*j:w*(j+1)+tp] += image_stack[None,i*k+j] + + # accumulate outer bands to opposite sides: + hp = s*p # half padding + out_padded[:,:,-tp:-hp,:] += out_padded[:,:,:hp,:] + out_padded[:,:,hp:tp,:] += out_padded[:,:,-hp:,:] + out_padded[:,:,:,-tp:-hp] += out_padded[:,:,:,:hp] + out_padded[:,:,:,hp:tp] += out_padded[:,:,:,-hp:] + + out = out_padded[:,:,hp:-hp,hp:-hp] # trim + image, *_ = pipe.image_processor.postprocess(out, output_type='pil', do_denormalize=[True]) + image.save(fname) + + elif args.stitch_mode == 'wmean': # weighted average kernel blending + p=1 + folded = patch(latents, k) + folded_padded = F.pad(folded, pad=(p,p,p,p), mode='circular') + unfolded_padded = unpatch(folded_padded, k, p) + + chunk_size = len(unfolded_padded)//16 or 1 + image_stack = [] + for chunk in unfolded_padded.chunk(chunk_size): + image = pipe.vae.decode(chunk / pipe.vae.config.scaling_factor) + image_stack.append(image.sample) + image_stack = torch.cat(image_stack) + + # lmean = image_stack.mean(dim=(-1,-2), keepdim=True) + # gmean = image_stack.mean(dim=(0,2,3), keepdim=True) + # image_stack = image_stack*gmean/lmean + + ## patch blending + scale = pipe.vae_scale_factor + tp = 2*scale*p # total padding + mask = get_kernel(scale*p, image_stack.device) # 1:8 in pixel space + # import pdb; pdb.set_trace() + # print(mask.shape) + image_stack *= mask + + # gather values into final tensor + _, c, hpad, wpad = image_stack.shape + h, w = hpad-tp, wpad-tp + out_padded = torch.zeros(batch_size, c, h*k+tp, w*k+tp, device=image_stack.device) + for i, j in product(range(k), range(k)): + out_padded[:,:,h*i:w*(i+1)+tp,h*j:w*(j+1)+tp] += image_stack[None,i*k+j] + + # accumulate outer bands to opposite sides: + hp = scale*p # half padding + out_padded[:,:,-tp:-hp,:] += out_padded[:,:,:hp,:] + out_padded[:,:,hp:tp,:] += out_padded[:,:,-hp:,:] + out_padded[:,:,:,-tp:-hp] += out_padded[:,:,:,:hp] + out_padded[:,:,:,hp:tp] += out_padded[:,:,:,-hp:] + + out = out_padded[:,:,hp:-hp,hp:-hp] # trim + image, *_ = pipe.image_processor.postprocess(out, output_type='pil', do_denormalize=[True]) + image.save(fname) + + if args.renorm: + from . import renorm + renorm(fname) + + return fname + +def infer(path, outdir=None, stitch_mode='wmean', renorm=False, resolution=1024, seed=1, prompt='p1', num_inference_steps=50): + return main(Namespace( + path=path, + outdir=outdir, + prompt=prompt, + token=None, + renorm=renorm, + stitch_mode=stitch_mode, + resolution=resolution, + seed=seed, + num_inference_steps=num_inference_steps)) + +if __name__ == "__main__": + args = parse_args() + print(args) + main(args) diff --git a/concept/invert.py b/concept/invert.py new file mode 100644 index 0000000..f26aca0 --- /dev/null +++ b/concept/invert.py @@ -0,0 +1,278 @@ +# Original source code: https://github.com/huggingface/peft +# The code is taken from examples/lora_dreambooth/train_dreambooth.py and performs LoRA Dreambooth +# finetuning, it was modified for integration in the Material Palette pipeline. It includes some +# minor modifications but is heavily refactored and commented to make it more digestible and clear. 
+# It is rather self-contained but avoids being +1000 lines long! The code has two interfaces: +# the original CLI and a functinal interface via `invert()`, they have the same default parameters! + +import os +import math +import itertools +from pathlib import Path +from argparse import Namespace + +import torch +from tqdm.auto import tqdm +import torch.utils.checkpoint +import torch.nn.functional as F +from accelerate.utils import set_seed +from diffusers.utils import check_min_version + +from concept.args import parse_args +from concept.utils import load_models, load_optimizer, load_logger, load_scheduler, load_dataloader, save_lora + +# Will throw error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + + +def main(args): + from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + + set_seed(args.seed) + + root_dir = Path(args.data_dir) + assert root_dir.is_dir() + + output_dir = args.prompt.replace(' ', '_') + output_dir = root_dir.parent.parent / 'weights' / root_dir.name / output_dir + output_dir.mkdir(exist_ok=True, parents=True) + + ckpt_path = output_dir / f'checkpoint-{args.max_train_steps}' / 'text_encoder' + if ckpt_path.is_dir(): + print(f'{ckpt_path} already exists') + return ckpt_path.parent + + if args.validation_prompt is not None: + output_dir_val = output_dir/'val' + output_dir_val.mkdir() + + ## Load dataset (earliest as possible to anticipate crashes) + train_dataset, train_dataloader = load_dataloader(args, root_dir) + + from accelerate import Accelerator # avoid preloading before directory validation + accelerator = Accelerator( + # gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=output_dir / args.logging_dir, + ) + + logger = load_logger(args, accelerator) + + ## Load scheduler and models + noise_scheduler, text_encoder, vae, unet = load_models(args) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + optimizer, args.learning_rate = load_optimizer(args, unet, text_encoder, accelerator.num_processes) + + lr_scheduler = load_scheduler(args, optimizer) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Initialize the trackers we use and store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! 
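+ # The loop below runs standard DreamBooth/LoRA fine-tuning on the region crops: each step encodes a crop with + # the frozen VAE, adds noise at a random timestep, and optimizes the LoRA-injected UNet (and optionally the + # text encoder) with an MSE loss between the predicted and the true noise (or the velocity target for + # v-prediction schedulers). Training stops after `max_train_steps` optimizer steps; e.g. with the default + # 800 steps and, say, 100 crops at batch size 1, that is ceil(800/100) = 8 epochs.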
+ total_batch_size = args.train_batch_size * accelerator.num_processes + + ##! remove args.num_train_epochs and args.gradient_accumulation_steps from CLI + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed) = {total_batch_size}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(output_dir, path)) + global_step = int(path.split("-")[1]) + + first_epoch = global_step // len(train_dataloader) + resume_step = global_step % len(train_dataloader) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) + for epoch in range(first_epoch, num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + + for step, (img, prompt) in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + # Embed the images to latent space and apply scale factor + latents = vae.encode(img.to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample a random timestep for each image + T = noise_scheduler.config.num_train_timesteps + timesteps = torch.randint(0, T, (len(latents),), device=latents.device, dtype=torch.long) + + # Forward diffusion process: add noise to the latents according to the noise magnitude + noise = torch.randn_like(latents) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(prompt)[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # L2 error reconstruction objective + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Backward pass on denoiser and optionnally text encoder + accelerator.backward(loss) + + # Gradient clipping step + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + 
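+ # Clip the gradient norm of every trainable parameter group (the UNet, plus the text encoder + # when it is being fine-tuned) to `max_grad_norm` before the optimizer step.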
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + # Handle optimzer and learning rate scheduler + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + _text_encoder = text_encoder if args.train_text_encoder else None + save_lora(accelerator, unet, _text_encoder, output_dir, global_step) + + # Log loss and learning rates + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + # Validation step + if (args.validation_prompt is not None) and (global_step % args.validation_steps == 0): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # Create pipeline for validation pass + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + local_files_only=True) + + # Set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # Set sampler generator seed + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + # Run inference + for i in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + image.save(output_dir / 'val' / f'{global_step}_{i}.png') + + del pipeline + torch.cuda.empty_cache() + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + _text_encoder = text_encoder if args.train_text_encoder else None + save_lora(accelerator, unet, _text_encoder, output_dir, global_step) + + accelerator.end_training() + return ckpt_path.parent + +DEFAULT_PROMPT = "an object with azertyuiop texture" +def invert(data_dir: str, prompt=DEFAULT_PROMPT, train_text_encoder=True, gradient_checkpointing=True, **kwargs) -> Path: + """ + Functional interface for the inversion step of the method. It adopts the same interface as + the CLI defined in `args.py` by `parse_args` (jump there for details). If the region has already + been inverted the function will exit early. Always returns the path of the inversion checkpoint. + + :param str `data_dir`: path of the directory containing the region crops to invert + :param str `prompt`: prompt used for inversion containing the rare token eg. 
"an object with zkjefb texture" + :return Path: the path to the inversion checkpoint + """ + all_args = parse_args(return_defaults=True) + all_args.update(data_dir=str(data_dir), + prompt=prompt, + train_text_encoder=train_text_encoder, + gradient_checkpointing=gradient_checkpointing, + **kwargs) + return main(Namespace(**all_args)) + +if __name__ == "__main__": + args = parse_args() + args.train_text_encoder = True + args.gradient_checkpointing = True + main(args) diff --git a/concept/renorm.py b/concept/renorm.py new file mode 100644 index 0000000..0f14e9e --- /dev/null +++ b/concept/renorm.py @@ -0,0 +1,62 @@ +import argparse, csv, random +from pathlib import Path +from PIL import Image + +import numpy as np +import cv2 +import torch +from tqdm import tqdm +import torchvision.transforms.functional as tf +from torchvision.utils import save_image +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + + +def renorm(path): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + renorm_dir = path.parent.parent/'out_renorm' + proj_dir = Path(*path.parts[:-6]) + + ## Get input rgb image + proposals = [x for x in proj_dir.iterdir() if x.is_file() and x.suffix in ('.jpg', '.png', '.jpeg')] + assert len(proposals) == 1 + img_path = proposals[0] + pil_img = Image.open(img_path) + color = tf.to_tensor(pil_img.convert('RGB')) + + ## Get region mask + mask_path = proj_dir/'masks'/f'{path.parts[-5]}.png' + assert mask_path.is_file() + mask = tf.to_tensor(Image.open(mask_path).convert('L')) + mask = tf.resize(mask, size=(pil_img.height, pil_img.width))[0] + + mask = mask == 1. + grayscale = tf.to_tensor(pil_img.convert('L'))[0] + gray_flat = grayscale[mask] + + # Flatten the grayscale and sort pixels + sorted_pixels, _ = gray_flat.sort() + exclude_count = int(0.005 * len(gray_flat)) + low_threshold = sorted_pixels[exclude_count] + high_threshold = sorted_pixels[-exclude_count] + + # construct the mask + m = (gray_flat >= low_threshold) & (gray_flat <= high_threshold) + + ref_flatten = color[:,mask] + ref = torch.stack([ref_flatten[0, m], ref_flatten[1, m], ref_flatten[2, m]]) + mean_ref = ref.mean(1)[:,None,None].to(device) + std_ref = ref.std(1)[:,None,None].to(device) + + # gather patches + renorm_dir.mkdir(exist_ok=True) + x = tf.to_tensor(Image.open(path))[None].to(device) + mean = x.mean(dim=(2,3),keepdim=True) + std = x.std(dim=(2,3),keepdim=True) + + # renormalize + x = (x-mean)/std * std_ref + mean_ref + x.clamp_(0,1) + + s_out = renorm_dir/path.name + tf.to_pil_image(x[0]).save(s_out) \ No newline at end of file diff --git a/concept/utils.py b/concept/utils.py new file mode 100644 index 0000000..b649308 --- /dev/null +++ b/concept/utils.py @@ -0,0 +1,196 @@ +import itertools +import logging + +import torch +import datasets +import diffusers +import transformers +from transformers import CLIPTextModel +from transformers import AutoTokenizer +from torch.utils.data import DataLoader +from accelerate.logging import get_logger +from peft import LoraConfig, get_peft_model +from diffusers.optimization import get_scheduler +from diffusers.utils.import_utils import is_xformers_available +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + +from .data import DreamBoothDataset + + +def load_models(args): + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000) + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + 
subfolder="text_encoder", + revision=args.revision) + + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + elif args.train_text_encoder and args.use_lora: + config = LoraConfig( + r=args.lora_text_encoder_r, + lora_alpha=args.lora_text_encoder_alpha, + target_modules=["q_proj", "v_proj"], + lora_dropout=args.lora_text_encoder_dropout, + bias=args.lora_text_encoder_bias) + text_encoder = get_peft_model(text_encoder, config) + text_encoder.print_trainable_parameters() + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision) + vae.requires_grad_(False) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision) + + if args.use_lora: + config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules= ["to_q", "to_v", "query", "value"], + lora_dropout=args.lora_dropout, + bias=args.lora_bias) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + + ## advanced options + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # below fails when using lora so commenting it out + if args.train_text_encoder and not args.use_lora: + text_encoder.gradient_checkpointing_enable() + + return noise_scheduler, text_encoder, vae, unet + + +def load_optimizer(args, unet, text_encoder, num_processes): + if args.scale_lr: + lr = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * num_processes + else: + lr = args.learning_rate + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + + if args.train_text_encoder: + params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_optimize = unet.parameters() + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon) + + return optimizer, lr + + +def load_logger(args, accelerator): + # Load wandb if needed + if args.report_to == 'wandb': + import wandb + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project_name) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + + logger = get_logger(__name__) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + return logger + + +def load_scheduler(args, optimizer): + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power) + return lr_scheduler + + +def save_lora(accelerator, unet, text_encoder, output_dir, global_step): + ckpt_dir = output_dir / f'checkpoint-{global_step}' + ckpt_dir.mkdir(exist_ok=True) + + unwrapped_unet = accelerator.unwrap_model(unet) + unet_dir = ckpt_dir / 'unet' + unwrapped_unet.save_pretrained(unet_dir, state_dict=accelerator.get_state_dict(unet)) + + if text_encoder: + unwrapped_text_encoder = accelerator.unwrap_model(text_encoder) + textenc_dir = ckpt_dir / 'text_encoder' + textenc_state = accelerator.get_state_dict(text_encoder) + unwrapped_text_encoder.save_pretrained(textenc_dir, state_dict=textenc_state) + + +def load_dataloader(args, root_dir): + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False) + + train_dataset = DreamBoothDataset( + data_dir=root_dir, + prompt=args.prompt, + tokenizer=tokenizer, + size=args.resolution) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + num_workers=1) + + return train_dataset, train_dataloader \ No newline at end of file diff --git a/deps.yml b/deps.yml new file mode 100644 index 0000000..12b152e --- /dev/null +++ b/deps.yml @@ -0,0 +1,17 @@ +# install: conda env create -f deps.yml +name: matpal +channels: + - pytorch + - nvidia +dependencies: + - pytorch==1.13.1 + - torchvision==0.14.1 + - pytorch-cuda=11.7 + - pip + - pip: + - lightning==1.8.3 + - diffusers==0.19.3 + - peft==0.5.0 + - opencv_python + - jsonargparse + - easydict \ No newline at end of file diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000..76ace15 --- /dev/null +++ b/pipeline.py @@ -0,0 +1,29 @@ +from pathlib import Path +from argparse import ArgumentParser + +from pytorch_lightning import Trainer + +import concept +import capture + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('path', type=Path) + args = parser.parse_args() + + ## Extract square crops from image for each of the binary masks located in /masks + regions = concept.crop(args.path) + + ## Iterate through regions to invert the concept and generate texture views + for region in regions.iterdir(): + lora = concept.invert(region) + concept.infer(lora, renorm=True) + + ## Construct a dataset with all generations and load pretrained decomposition model + data = capture.get_data(predict_dir=args.path, predict_ds='sd') + module = 
capture.get_inference_module(pt='model.ckpt') + + ## Proceed with inference on decomposition model + decomp = Trainer(default_root_dir=args.path, accelerator='gpu', devices=1, precision=16) + decomp.predict(module, data) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..790d31b --- /dev/null +++ b/train.py @@ -0,0 +1,19 @@ +from pytorch_lightning import seed_everything +from capture.utils import Trainer, get_args, get_module, get_name, get_data, get_callbacks, get_logger + + +if __name__ == '__main__': + args = get_args() + + seed_everything(seed=args.seed) + + data = get_data(args.data) + module = get_module(args) + + args.name = get_name(args) + args.out_dir = args.out_dir/args.name + callbacks = get_callbacks(args) + logger = get_logger(args) + + trainer = Trainer(args, default_root_dir=args.out_dir, logger=logger, callbacks=callbacks, **args.trainer) + trainer(args.mode, module, data) \ No newline at end of file