diff --git a/library/train_util.py b/library/train_util.py index b40945ab8..47c367683 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3014,7 +3014,11 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -4105,6 +4109,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -4162,7 +4167,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4338,6 +4344,7 @@ def get_optimizer(args, trainable_params): optimizer_class = getattr(optimizer_module, optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])