diff --git a/utils/__init__.py b/utils/__init__.py index df052d37..9dc7b4d8 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -211,14 +211,6 @@ def load_ckpt( print(f'| load {shown_model_name} from \'{checkpoint_path}\'.') - - - - - - # return load_pre_train_model() - - def remove_padding(x, padding_idx=0): if x is None: return None @@ -273,6 +265,13 @@ def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): def build_lr_scheduler_from_config(optimizer, scheduler_args): + try: + # PyTorch 2.0+ + from torch.optim.lr_scheduler import LRScheduler as LRScheduler + except ImportError: + # PyTorch 1.X + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + def helper(params): if isinstance(params, list): return [helper(s) for s in params] @@ -287,18 +286,19 @@ def helper(params): resolved['optimizer'] = optimizer obj = build_object_from_class_name( resolved['cls'], - torch.optim.lr_scheduler.LRScheduler, + LRScheduler, **resolved ) return obj return resolved else: return params + resolved = helper(scheduler_args) resolved['optimizer'] = optimizer return build_object_from_class_name( scheduler_args['scheduler_cls'], - torch.optim.lr_scheduler.LRScheduler, + LRScheduler, **resolved )