Support for optimizers which don't have "fused" parameter such as grokadamw and 8bit bnb #1744
Conversation
litgpt/utils.py
Outdated
```python
if "." in optimizer:
    class_module, class_name = optimizer.rsplit(".", 1)
else:
    class_module, class_name = "torch.optim", optimizer

module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)

if "fused" in kwargs and "fused" not in inspect.signature(optimizer_cls).parameters:
    kwargs = dict(kwargs)  # copy
    del kwargs["fused"]
```
Thanks for the PR! I am thinking maybe we can make this even more general so that it would also remove other unsupported arguments if present, something like
Suggested change — before:

```python
if "." in optimizer:
    class_module, class_name = optimizer.rsplit(".", 1)
else:
    class_module, class_name = "torch.optim", optimizer
module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)
if "fused" in kwargs and "fused" not in inspect.signature(optimizer_cls).parameters:
    kwargs = dict(kwargs)  # copy
    del kwargs["fused"]
```

After:

```python
if "." in optimizer:
    class_module, class_name = optimizer.rsplit(".", 1)
else:
    class_module, class_name = "torch.optim", optimizer
module = __import__(class_module, fromlist=[class_name])
optimizer_cls = getattr(module, class_name)
valid_params = set(inspect.signature(optimizer_cls).parameters)
kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
optimizer = optimizer_cls(model_parameters, **kwargs)
```
Thanks for the PR, this looks really good to me. What do you think about the suggestions above?
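The generalized filtering proposed above can be sketched as a self-contained helper. This is a sketch based on the suggestion, not litgpt's actual code; the function name `instantiate_optimizer` is ours. It resolves the optimizer class from a bare name or dotted path, then drops any keyword arguments the class does not accept (e.g. `fused` for grokadamw or 8-bit bnb optimizers):

```python
import inspect

def instantiate_optimizer(optimizer, model_parameters, **kwargs):
    # Resolve the class: dotted paths are imported as-is,
    # bare names are looked up in torch.optim.
    if "." in optimizer:
        class_module, class_name = optimizer.rsplit(".", 1)
    else:
        class_module, class_name = "torch.optim", optimizer
    module = __import__(class_module, fromlist=[class_name])
    optimizer_cls = getattr(module, class_name)

    # Keep only keyword arguments the constructor actually declares,
    # so unsupported ones like "fused" are silently dropped.
    valid_params = set(inspect.signature(optimizer_cls).parameters)
    kwargs = {key: value for key, value in kwargs.items() if key in valid_params}
    return optimizer_cls(model_parameters, **kwargs)
```

With this, a call such as `instantiate_optimizer("AdamW", model.parameters(), lr=1e-3, fused=True)` works even for optimizer classes whose constructors lack a `fused` parameter; note the filtering also hides genuinely misspelled kwargs, which is the trade-off of the more general approach.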
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
@mtasic85 I noticed that elsewhere too. It's related to yesterday's transformers release (4.45), which introduced a backward-incompatible change. You are right, it has nothing to do with your PR. And no worries, I will fix this now.
Alright! So there were two things that broke CI yesterday: a new transformers release and a new jsonargparse release. Both should be addressed now.
Looks good now; it should be ready to merge. Thanks again for submitting this PR!
This fixes #1743: when `optimizer` isn't a `str` nor a `dict`, an error is thrown.
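The type check described above can be sketched as follows. This is a minimal illustration with a hypothetical function name (`check_optimizer_arg`); litgpt's actual validation lives elsewhere and may word the error differently:

```python
def check_optimizer_arg(optimizer):
    # The optimizer may be given as a name/dotted path (str) or as a
    # config mapping (dict); anything else is rejected early with a
    # clear error instead of failing later during instantiation.
    if not isinstance(optimizer, (str, dict)):
        raise ValueError(
            f"Unsupported optimizer specification of type {type(optimizer).__name__!r}; "
            "expected str or dict"
        )
```

Rejecting bad input up front keeps the failure close to the user's config rather than deep inside the import/instantiation machinery.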