Commit a13c4e6 (initial commit, 0 parents): 22 changed files with 402 additions and 0 deletions.
@@ -0,0 +1,8 @@
root = true

[*]
end_of_line = lf
insert_final_newline = true
indent_size = 4
indent_style = tab
trim_trailing_whitespace = true
@@ -0,0 +1,7 @@
[flake8]
select = E3, E4, F, I1, I2
plugins = flake8-import-order
application_import_names = arcface_converter
import-order-style = pycharm
per-file-ignores = preparing.py:E402
@@ -0,0 +1,2 @@
github: henryruhs
custom: [ buymeacoffee.com/henryruhs, paypal.me/henryruhs ]
Binary file (not displayed).
@@ -0,0 +1,19 @@
name: ci

on: [ push, pull_request ]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
    - name: Checkout
      uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    - run: pip install flake8
    - run: pip install flake8-import-order
    - run: pip install mypy
    - run: flake8 arcface_converter
    - run: mypy arcface_converter
@@ -0,0 +1,2 @@
.idea
.vscode
@@ -0,0 +1,3 @@
MIT license

Copyright (c) 2024 Henry Ruhs
@@ -0,0 +1,7 @@
FaceFusion Labs
===============

> Industry leading face manipulation platform.

[![Build Status](https://img.shields.io/github/actions/workflow/status/facefusion/facefusion-labs/ci.yml.svg?branch=master)](https://github.com/facefusion/facefusion-labs/actions?query=workflow:ci)
![License](https://img.shields.io/badge/license-MIT-green)
@@ -0,0 +1,34 @@
ArcFace Converter
=================

> Convert face embeddings between various ArcFace models.

Preview
-------

![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/preview_arcface_converter.png?sanitize=true)


Preparing
---------

```
python prepare.py
```


Training
--------

```
python train.py
```


Exporting
---------

```
python export.py
```
Empty file.
@@ -0,0 +1,34 @@
[preparing.dataset]
dataset_path = datasets/dataset/train.rec
crop_size = 112
process_limit = 650000

[preparing.model]
source_path = models/arcface_source.onnx
target_path = models/arcface_target.onnx

[preparing.input]
directory_path = inputs
source_path = inputs/arcface_source.npy
target_path = inputs/arcface_target.npy

[training.loader]
split_ratio = 0.8
batch_size = 51200
num_workers = 8

[training.trainer]
max_epochs = 4096

[training.output]
directory_path = outputs
file_pattern = arcface_converter_{epoch:02d}_{val_loss:.4f}

[exporting]
directory_path = exports
source_path = outputs/last.ckpt
target_path = exports/arcface_converter.onnx
opset_version = 15

[execution]
providers = CUDAExecutionProvider
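The Python modules in this commit read these values through configparser; a minimal sketch of the lookup pattern they share, using section and option names defined above:

```
import configparser

# Read config.ini once, mirroring the CONFIG pattern used by src.preparing, src.training and src.exporting.
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')

# configparser returns strings; the typed getters convert where needed.
batch_size = CONFIG.getint('training.loader', 'batch_size')        # 51200
split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')    # 0.8
providers = CONFIG.get('execution', 'providers').split(' ')        # ['CUDAExecutionProvider']
```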
@@ -0,0 +1,6 @@
#!/usr/bin/env python3

from src.exporting import export

if __name__ == '__main__':
	export()
@@ -0,0 +1,6 @@
#!/usr/bin/env python3

from src.preparing import prepare

if __name__ == '__main__':
	prepare()
Empty file.
@@ -0,0 +1,22 @@
import configparser
from os import makedirs

import torch

from .training import ArcFaceConverterTrainer

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')


def export() -> None:
	directory_path = CONFIG.get('exporting', 'directory_path')
	source_path = CONFIG.get('exporting', 'source_path')
	target_path = CONFIG.get('exporting', 'target_path')
	opset_version = CONFIG.getint('exporting', 'opset_version')

	makedirs(directory_path, exist_ok = True)
	model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
	model.eval()
	input_tensor = torch.randn(1, 512)
	torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
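A quick way to verify the exported graph is to run it once through onnxruntime; a minimal sketch, assuming export() above has written exports/arcface_converter.onnx and that CPU execution is acceptable for the check:

```
import numpy
from onnxruntime import InferenceSession

# Load the exported converter and feed one random 512-dimensional embedding.
session = InferenceSession('exports/arcface_converter.onnx', providers = [ 'CPUExecutionProvider' ])
input_embedding = numpy.random.randn(1, 512).astype(numpy.float32)
output_embedding = session.run(None, { 'input': input_embedding })[0]
print(output_embedding.shape)  # expected (1, 512)
```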
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn
from torch import Tensor


class ArcFaceConverter(nn.Module):
	def __init__(self) -> None:
		super(ArcFaceConverter, self).__init__()
		self.fc1 = nn.Linear(512, 1024)
		self.fc2 = nn.Linear(1024, 2048)
		self.fc3 = nn.Linear(2048, 1024)
		self.fc4 = nn.Linear(1024, 512)
		self.activation = nn.LeakyReLU()

	def forward(self, inputs : Tensor) -> Tensor:
		norm_inputs = inputs / torch.norm(inputs)
		outputs = self.activation(self.fc1(norm_inputs))
		outputs = self.activation(self.fc2(outputs))
		outputs = self.activation(self.fc3(outputs))
		outputs = self.fc4(outputs)
		return outputs
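For orientation, the converter is a plain MLP that maps one 512-dimensional embedding to another; a minimal usage sketch of the class defined above, assuming it runs from a directory where the src package is importable:

```
import torch

from src.model import ArcFaceConverter

# Build the converter and map a batch of random source embeddings to target space.
model = ArcFaceConverter()
model.eval()

with torch.no_grad():
	source_embeddings = torch.randn(2, 512)
	target_embeddings = model(source_embeddings)

print(target_embeddings.shape)  # torch.Size([2, 512])
```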
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

import configparser
from os import makedirs
from os.path import isfile
from typing import List

import numpy
numpy.bool = numpy.bool_
from mxnet.io import ImageRecordIter
from onnxruntime import InferenceSession
from tqdm import tqdm

from .typing import Embedding, EmbeddingPairs, VisionFrame

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')


def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
	crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255
	crop_vision_frame = (crop_vision_frame - 0.5) * 2
	return crop_vision_frame


def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
	inference_session = InferenceSession(model_path, providers = execution_providers)
	return inference_session


def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
	embedding = inference_session.run(None,
	{
		'input': crop_vision_frame
	})[0]

	return embedding


def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs:
	dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
	embedding_pairs = []

	with tqdm(total = dataset_process_limit) as progress:
		for batch in dataset_reader:
			crop_vision_frame = batch.data[0].asnumpy()
			crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame)
			source_embedding = forward(source_inference_session, crop_vision_frame)
			target_embedding = forward(target_inference_session, crop_vision_frame)
			embedding_pairs.append([ source_embedding, target_embedding ])
			progress.update()

			if progress.n == dataset_process_limit:
				return numpy.concatenate(embedding_pairs, axis = 1).T

	return numpy.concatenate(embedding_pairs, axis = 1).T


def prepare() -> None:
	dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
	dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
	model_source_path = CONFIG.get('preparing.model', 'source_path')
	model_target_path = CONFIG.get('preparing.model', 'target_path')
	input_directory_path = CONFIG.get('preparing.input', 'directory_path')
	input_source_path = CONFIG.get('preparing.input', 'source_path')
	input_target_path = CONFIG.get('preparing.input', 'target_path')
	execution_providers = CONFIG.get('execution', 'providers').split(' ')

	makedirs(input_directory_path, exist_ok = True)
	if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path):
		dataset_reader = ImageRecordIter(
			path_imgrec = dataset_path,
			data_shape = (3, dataset_crop_size, dataset_crop_size),
			batch_size = 1,
			shuffle = False
		)
		source_inference_session = create_inference_session(model_source_path, execution_providers)
		target_inference_session = create_inference_session(model_target_path, execution_providers)
		embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session)
		numpy.save(input_source_path, embedding_pairs[..., 0].T)
		numpy.save(input_target_path, embedding_pairs[..., 1].T)
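After prepare() has run, the two .npy files hold aligned source and target embeddings, one row per face crop; a minimal sketch for inspecting them, assuming the default paths from config.ini:

```
import numpy

# Rows at the same index come from the same dataset image.
source_embeddings = numpy.load('inputs/arcface_source.npy')
target_embeddings = numpy.load('inputs/arcface_target.npy')

print(source_embeddings.shape, target_embeddings.shape)  # (N, 512) each, N up to process_limit
assert source_embeddings.shape[0] == target_embeddings.shape[0]
```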
@@ -0,0 +1,118 @@
#!/usr/bin/env python3

import configparser
from typing import Any, Tuple

import numpy
import pytorch_lightning
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.tuner.tuning import Tuner
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

from .model import ArcFaceConverter
from .typing import Batch, Loader

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')


class ArcFaceConverterTrainer(pytorch_lightning.LightningModule):
	def __init__(self) -> None:
		super(ArcFaceConverterTrainer, self).__init__()
		self.model = ArcFaceConverter()
		self.loss_fn = torch.nn.MSELoss()
		self.lr = 0.001

	def forward(self, source_embedding : Tensor) -> Tensor:
		return self.model(source_embedding)

	def training_step(self, batch : Batch, batch_index : int) -> Tensor:
		source, target = batch
		output = self(source)
		loss = self.loss_fn(output, target)
		self.log('train_loss', loss, prog_bar = True, logger = True)
		return loss

	def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
		source, target = batch
		output = self(source)
		loss = self.loss_fn(output, target)
		self.log('val_loss', loss, prog_bar = True, logger = True)
		return loss

	def configure_optimizers(self) -> Any:
		optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

		return\
		{
			'optimizer': optimizer,
			'lr_scheduler':
			{
				'scheduler': scheduler,
				'monitor': 'train_loss',
				'interval': 'epoch',
				'frequency': 1
			}
		}


def create_loaders() -> Tuple[Loader, Loader]:
	loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
	loader_num_workers = CONFIG.getint('training.loader', 'num_workers')

	training_dataset, validate_dataset = split_dataset()
	training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True)
	validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True)
	return training_loader, validation_loader


def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
	input_source_path = CONFIG.get('preparing.input', 'source_path')
	input_target_path = CONFIG.get('preparing.input', 'target_path')
	loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')

	source_input = torch.from_numpy(numpy.load(input_source_path)).float()
	target_input = torch.from_numpy(numpy.load(input_target_path)).float()
	dataset = TensorDataset(source_input, target_input)

	dataset_size = len(dataset)
	training_size = int(loader_split_ratio * len(dataset))
	validation_size = int(dataset_size - training_size)
	training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
	return training_dataset, validate_dataset


def create_trainer() -> Trainer:
	trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
	output_directory_path = CONFIG.get('training.output', 'directory_path')
	output_file_pattern = CONFIG.get('training.output', 'file_pattern')

	return Trainer(
		max_epochs = trainer_max_epochs,
		callbacks =
		[
			ModelCheckpoint(
				monitor = 'train_loss',
				dirpath = output_directory_path,
				filename = output_file_pattern,
				every_n_epochs = 10,
				save_top_k = 3,
				save_last = True
			)
		],
		enable_progress_bar = True,
		log_every_n_steps = 2
	)


def train() -> None:
	trainer = create_trainer()
	training_loader, validation_loader = create_loaders()
	model = ArcFaceConverterTrainer()
	tuner = Tuner(trainer)
	tuner.lr_find(model, training_loader, validation_loader)
	trainer.fit(model, training_loader, validation_loader)
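The train.py entry point referenced in the README is not shown in this excerpt; by analogy with prepare.py and export.py it would presumably be a thin wrapper around train(), sketched here as an assumption:

```
#!/usr/bin/env python3

# Hypothetical entry point mirroring prepare.py and export.py; not part of the excerpt above.
from src.training import train

if __name__ == '__main__':
	train()
```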
@@ -0,0 +1,13 @@
from typing import Any, Tuple

from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader

Batch = Tuple[Tensor, Tensor]
Loader = DataLoader[Tuple[Tensor, ...]]

Embedding = NDArray[Any]
EmbeddingPairs = NDArray[Any]
FaceLandmark5 = NDArray[Any]
VisionFrame = NDArray[Any]