
Merge pull request kohya-ss#1285 from ccharest93/main
Hyperparameter tracking
kohya-ss authored May 19, 2024
2 parents e3ddd1f + b886d0a commit 47187f7
Showing 10 changed files with 36 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
@@ -310,7 +310,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
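Note: the same one-line change recurs in every training script below. The dict returned by the new train_util.filter_sensitive_args helper is passed as the config argument of accelerate's Accelerator.init_trackers, which records it as the run's hyperparameters with each tracker it starts (e.g. wandb or tensorboard). A minimal standalone sketch of the pattern, with illustrative values that are not from the commit:

from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
init_kwargs = {"wandb": {"name": "example-run"}}  # hypothetical run name
accelerator.init_trackers(
    "finetuning",
    config={"learning_rate": 1e-4, "max_train_steps": 1000},  # logged as hyperparameters
    init_kwargs=init_kwargs,
)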
27 changes: 27 additions & 0 deletions library/train_util.py
@@ -3388,6 +3388,33 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
         help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
     )

+def filter_sensitive_args(args: argparse.Namespace):
+    sensitive_args = ["wandb_api_key", "huggingface_token"]
+    sensitive_path_args = [
+        "pretrained_model_name_or_path",
+        "vae",
+        "tokenizer_cache_dir",
+        "train_data_dir",
+        "conditioning_data_dir",
+        "reg_data_dir",
+        "output_dir",
+        "logging_dir",
+    ]
+    filtered_args = {}
+    for k, v in vars(args).items():
+        # filter out sensitive values
+        if k not in sensitive_args + sensitive_path_args:
+            # Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
+            if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
+                filtered_args[k] = v
+            # accelerate does not support lists
+            elif isinstance(v, list):
+                filtered_args[k] = f"{v}"
+            # accelerate does not support objects
+            elif isinstance(v, object):
+                filtered_args[k] = f"{v}"
+
+    return filtered_args

 # verify command line args for training
 def verify_command_line_training_args(args: argparse.Namespace):
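For reference, a quick sketch (not part of the commit) of how filter_sensitive_args behaves on a toy argparse.Namespace: secrets and local paths are dropped, primitive values pass through unchanged, and lists or other objects are stringified, since accelerate's tracker config only accepts bool, str, float, int, or None:

import argparse

from library import train_util

args = argparse.Namespace(
    learning_rate=1e-4,                    # float: kept as-is
    wandb_api_key="secret",                # listed in sensitive_args: dropped
    output_dir="/home/me/output",          # listed in sensitive_path_args: dropped
    optimizer_args=["weight_decay=0.01"],  # list: stringified for the tracker
)
config = train_util.filter_sensitive_args(args)
# -> {'learning_rate': 0.0001, 'optimizer_args': "['weight_decay=0.01']"}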
2 changes: 1 addition & 1 deletion sdxl_train.py
@@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     sdxl_train_util.sample_images(
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
@@ -354,7 +354,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite_old.py
@@ -324,7 +324,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_controlnet.py
@@ -344,7 +344,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_db.py
@@ -290,7 +290,7 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)

     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
2 changes: 1 addition & 1 deletion train_network.py
@@ -774,7 +774,7 @@ def load_model_hook(models, input_dir):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_textual_inversion.py
@@ -510,7 +510,7 @@ def train(self, args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     # function for saving/removing
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
@@ -407,7 +407,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )

     # function for saving/removing
