Skip to content

Commit

Permalink
Fix error importing torch.optim.lr_scheduler.LRScheduler in PyTorch…
Browse files Browse the repository at this point in the history
… 1.13
  • Loading branch information
yqzhishen committed Aug 25, 2023
1 parent 57087ca commit 431862f
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,6 @@ def load_ckpt(
print(f'| load {shown_model_name} from \'{checkpoint_path}\'.')







# return load_pre_train_model()


def remove_padding(x, padding_idx=0):
if x is None:
return None
Expand Down Expand Up @@ -273,6 +265,13 @@ def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs):


def build_lr_scheduler_from_config(optimizer, scheduler_args):
try:
# PyTorch 2.0+
from torch.optim.lr_scheduler import LRScheduler as LRScheduler
except ImportError:
# PyTorch 1.X
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

def helper(params):
if isinstance(params, list):
return [helper(s) for s in params]
Expand All @@ -287,18 +286,19 @@ def helper(params):
resolved['optimizer'] = optimizer
obj = build_object_from_class_name(
resolved['cls'],
torch.optim.lr_scheduler.LRScheduler,
LRScheduler,
**resolved
)
return obj
return resolved
else:
return params

resolved = helper(scheduler_args)
resolved['optimizer'] = optimizer
return build_object_from_class_name(
scheduler_args['scheduler_cls'],
torch.optim.lr_scheduler.LRScheduler,
LRScheduler,
**resolved
)

Expand Down

0 comments on commit 431862f

Please sign in to comment.