Design suggestion: remove forward from list of methods to override #838
Comments
I recently started using Lightning for a project I have been working on and I needed to import the model from a separate module like you stated @elistevens. In my Lightning init I just instantiate my external model and override forward to return mymodel(x). This works fine; however, I agree that it might be better to have the model as an attribute as opposed to Lightning being the model.
This would also help with more complicated research projects that involve multiple models (autoencoders or GANs, for example) and make things a lot more flexible and pythonic, "pytorchic."
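A minimal sketch of the wrap-and-delegate pattern described above; the model class, tensor shapes, loss, and optimizer here are placeholders for illustration, not taken from the thread:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class MyExternalModel(nn.Module):
    """Stand-in for a model defined in a separate, lightweight package."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 4)

    def forward(self, x):
        return self.net(x)


class LitWrapper(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # The model is an attribute of the LightningModule, not the module itself.
        self.my_model = MyExternalModel()

    def forward(self, x):
        # forward simply delegates to the wrapped model.
        return self.my_model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```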
@darwinkim I agree with @elistevens that it will be useful to be able to "extract" a more lightweight model to ship to production. However, can you provide an example as to how this would help with more complicated research projects that involve multiple models? There's a GAN example here which shows how you can cleanly incorporate multiple models.
I wonder if there's a way we could expose a
@jeremyjordan
just a note: PR #1211 promotes the use of self(...) instead of self.forward in examples and docs.
@williamFalcon @PyTorchLightning/core-contributors ^^
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
@williamFalcon asked me to revisit this, so I'm adding some more thoughts. PR #1211 fixed the issue of suggesting that users call self.forward() directly.

Essentially, I'm wanting to clearly and cleanly separate concerns, and have that clean separation be suggested by the documentation. From an OOP perspective, the documentation suggests that the training loop object and the model object be the same object, and that mixes two separate concepts. Put another way, if you were going to be using a stock model from an existing library, the documented pattern of having the LightningModule be the model wouldn't fit.

The way it's suggested now, it becomes much harder to pull out my model and use it in some other context (like a different training loop). I typically try to avoid libraries with that kind of lock-in.
Thanks for adding more details! I use lightning a lot the way you describe. What gives you the impression that you can't use it this way? Is there a better way to show this in the docs or examples?

First, take a look at all the bolts models. Most models in bolts have that pattern. Second, when you are done you can load the full thing and pull out whatever parts are interesting to you (i.e. just the encoder of a GAN), or make the forward use only the encoder. But yeah, you can always drop a model into a LightningModule and use the LightningModule purely as training loops for the model.
https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol

Finally, we can make some example or write in the docs what you're more clearly looking for if you'd prefer:

```python
class ClassificationTask(LightningModule):
    model = Resnet50()
```

In fact, we can add a new section to bolts with these prebuilt loops. Classification loops, fine-tuning loop, etc...
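A slightly fuller sketch of that prebuilt-loop idea, as one possible reading of the snippet above; the loss, optimizer, and the use of torchvision's resnet50 in place of Resnet50 are assumptions for illustration:

```python
import torch
from torch import nn
import pytorch_lightning as pl
from torchvision.models import resnet50


class ClassificationTask(pl.LightningModule):
    """A reusable training loop that wraps any plain classification model."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # any nn.Module, passed in from outside

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        return nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)


# The wrapped model stays a plain torch module and can be reused elsewhere.
task = ClassificationTask(resnet50())
```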
Can we at least raise a NotImplementedError like PyTorch does? I only just now noticed that in the current version, LightningModule actually implements forward for you, returning None. Why is that?
it does, but forward is not required... we want to separate training from inference. in training you use the __step methods. if your model also happens to do inference, then it should implement forward. this makes a clean separation between pure training scripts and models. this removal also enables tasks which weren't possible before.
All of this is clear. No problem with that. If you don't use forward, all is good.

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule


class Lightning(LightningModule):
    pass

class Torch(nn.Module):
    pass

lightning_model = Lightning()
print(lightning_model(torch.rand(2, 2)))  # does not raise, returns None, why?

torch_model = Torch()
print(torch_model(torch.rand(2, 2)))  # raises NotImplementedError, good!
```
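One possible local workaround for the behaviour asked about above, as a sketch rather than anything LightningModule provided at the time: override forward yourself so that accidental calls fail loudly.

```python
import pytorch_lightning as pl


class StrictLightning(pl.LightningModule):
    def forward(self, *args, **kwargs):
        # Mirror nn.Module's default behaviour of refusing unimplemented forward calls.
        raise NotImplementedError("forward is intentionally not implemented on this module")
```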
Hey, any updates on this issue? Best,
all items here were addressed a while ago. we can close
🚀 Feature (?)
This is more a philosophical design suggestion than a feature request.
I think that the presentation of LightningModule as a torch.Module-plus-features encourages early experiment designs that don't refactor nicely as the projects using it grow. I also think that calling self.forward directly is a torch anti-pattern, and should not be encouraged. I'd like the official docs to suggest using self.my_model = MyModel(...) in __init__ and y = self.my_model(x) in training_step etc.

Motivation
I think that most non-research uses of lightning are going to require that the environment the model is trained in be separable from the model itself. This is most obvious when considering the infrastructure needed to load training data vs. production inference data; you're not going to want to drag along all of the libraries needed to connect to a database, decompress data, etc. in the production environment.
To do so, I'd need to be able to from some.other.package import MyModel and then self.my_model = MyModel(...) in __init__. As long as some.other.package doesn't have extra dependencies, I can ship my production model and weights to production without needing everything else that lightning, etc. depends on.

By suggesting that users have the lightning subclass be the model, the set of packages that need to be present in production goes up quite a bit (speaking from experience, the pip version management becomes painful).
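A hedged sketch of what that production-side handoff could look like under the proposed pattern; the checkpoint filename and the my_model attribute name are assumptions for illustration, and the state_dict layout shown is how Lightning checkpoints are typically structured:

```python
import torch
from some.other.package import MyModel

# Load the checkpoint written by Lightning during training (filename is made up).
checkpoint = torch.load("my_run.ckpt", map_location="cpu")

# Keep only the weights belonging to the wrapped model and strip the attribute prefix.
weights = {
    key[len("my_model."):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith("my_model.")
}

model = MyModel()
model.load_state_dict(weights)
model.eval()  # ready for inference without importing lightning at all
```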
Another thing that this makes unclear, then, is what is actually happening when training_step gets called. The suggestion "Normally you'd call self.forward() from your training_step() method." implies that self.training_step is happening inside of a self.__call__, since torch.nn.Module.forward isn't supposed to be called directly (it's __call__ that handles hooks, etc.), but that doesn't actually seem to be the case. Unless I'm missing something, this really feels like misuse of the torch API.

By making it clear that your LightningModule subclass should have an instance of your model as an attribute, not be the model, all of the above gets cleared up quite a bit.

Pitch
I think it's a lot cleaner and clearer to say "Normally you'd call y = self.my_model(x) from your training_step() method." and remove any suggestion of overriding self.forward() from the documentation (and I'd in fact make the default implementation of forward raise a YouAreDoingItWrongException).

As I said earlier, I think that projects that mix training and model code in the same class are going to have a difficult time refactoring things later on, and I think that the perceived simplicity early on is a mastery trap. Anyone familiar with PyTorch isn't going to have a problem defining a separate model class.
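To make the refactoring argument concrete, here is a small sketch under the proposed pattern: the model is a plain nn.Module with nothing Lightning-specific baked in, so the same class can be dropped into a LightningModule for training or into an ordinary PyTorch loop. The model, data, and hyperparameters are placeholders:

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Plain model with no Lightning dependency."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)

    def forward(self, x):
        return self.net(x)


# The same model trained with a hand-rolled loop that knows nothing about Lightning.
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(10):
    x = torch.randn(32, 8)
    y = torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```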
Alternatives
Note that I don't think there's anything preventing me from implementing models the way I think is proper right now, but I'm currently doing an investigation into whether we can use lightning for more projects in the organization, and I'd really rather not have to try and educate users to ignore the docs and do it the self.my_model() way instead.

At the very least, changing the documentation to say "Normally you'd call y = self(x) from your training_step() method." makes sure that hooks, etc. get called as expected.

Additional context
Now, I will fully admit that I haven't dug into lightning a ton yet, so it's possible that I'm missing something that will change my understanding/perception of things. If that's the case, I think it should be articulated more clearly.
Thanks for reading.