Possible bug: Torch backend backpropagation #604
Comments
This is not intended to be supported at this time. Only ...
Is it possible to make new KerasCore layers abstracting torch Modules?
It's possible, but you'd need to wrap the module in a layer in a way that makes it track the underlying ...
This would be a valuable feature, making ...
Would you like to explore implementing a prototype of a ...
Yes! This is something I would love to explore.
It's probably going to be something like this (warning: entirely untested):

```python
# Assumed imports (not part of the original, untested snippet):
from keras_core import Variable
from keras_core.layers import Layer


class TorchModuleWarpper(Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module

    def parameters(self, recurse=True):
        # Expose the wrapped module's parameters so torch optimizers still see them.
        return self.module.parameters(recurse=recurse)

    def build(self, _):
        if not self.built:
            # Register each torch parameter as a tracked Keras variable.
            for param in self.module.parameters():
                variable = Variable(value=param, trainable=param.requires_grad)
                self._track_variable(variable)
            self.built = True

    def call(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)
```
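For illustration, here is a minimal usage sketch of the wrapper above (hypothetical; it assumes the keras_core torch backend and that the untested snippet behaves as intended):

```python
import torch.nn as nn

# Hypothetical usage: wrap an existing torch module and treat it like a Keras layer.
torch_mlp = nn.Sequential(
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)
wrapped = TorchModuleWarpper(torch_mlp, name="torch_mlp")
# Once built, the wrapped module's parameters would be tracked as Keras variables,
# so they can show up in layer.weights and be updated by a Keras optimizer.
```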
This would be a very welcome feature: being able to just use `.fit()`, `.evaluate()`, etc. on, for example, all of the models in https://docs.monai.io/en/stable/networks.html
I attempted to make the snippet shared by @fchollet work; here's where I'm currently stuck: some issue with inferring the shapes. Even specifically overriding the ...
@soumik12345 I have it working with a couple of simple changes:
https://colab.research.google.com/drive/1gG93Fb03Ef-77suS6b7PqgWtMOopNyyF?usp=sharing

Update -- actually there is a problem with variable tracking. It doesn't yet work.

Update -- fixed it by setting ...
@fchollet ...
Also, should this be an officially supported feature in Keras Core? If not a feature, maybe this could be mentioned in a guide?
Sure, let's add it. First, please investigate the case where not all module parameters are created upon instantiation (lazy module). When that works, we can start a PR.
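For reference, a minimal illustration of the lazy-module case referred to above: with torch's lazy layers, parameters do not exist with concrete shapes until the first forward pass, so they cannot be tracked at instantiation time.

```python
import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=8)  # in_features is inferred on the first call
print(isinstance(lazy, nn.modules.lazy.LazyModuleMixin))  # True
print(lazy.weight)                    # UninitializedParameter -- no shape yet

_ = lazy(torch.ones(2, 16))           # first forward pass materializes the weights
print(lazy.weight.shape)              # torch.Size([8, 16])
```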
Hi @fchollet, linked notebook: https://colab.research.google.com/drive/124Vn_d7_WvE2UaieLaG-pQavQuB7xHV3?usp=sharing
Thanks for the analysis! Is there a way to merge both classes into a single one? Instead of doing an eager ...
Yes, let me attempt this.
Here's how the merged class looks:

```python
# Assumed imports (not part of the original snippet):
import torch
import torch.nn as nn
import keras_core as keras


class TorchModuleWarpper(keras.layers.Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module.to("cuda")
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        # Non-lazy modules already have all of their parameters; track them now.
        if not self.lazy:
            self.track_module_parameters()

    def parameters(self, recurse=True):
        return self.module.parameters(recurse=recurse)

    def track_module_parameters(self):
        # Register each torch parameter as a tracked Keras variable.
        for param in self.module.parameters():
            variable = keras.backend.Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape):
        # Lazy modules need one eager forward pass to materialize their parameters.
        sample_input = torch.ones(*input_shape).to("cuda")
        _ = self.module(sample_input)
        self.track_module_parameters()

    def call(self, inputs, *args, **kwargs):
        if not self.built:
            self.build(inputs.shape[1:])
        return self.module.forward(inputs, *args, **kwargs)
```

Here's the updated notebook: https://colab.research.google.com/drive/1vy11aILVgrPtxKHAiohtKM6LlmpVj2EC?usp=sharing
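As a rough end-to-end sketch of what this enables (hypothetical model and shapes; it assumes a CUDA device, since the wrapper moves the module to "cuda", and the keras_core torch backend):

```python
import numpy as np
import torch.nn as nn
import keras_core as keras

class TinyModel(keras.Model):
    def __init__(self):
        super().__init__()
        # A plain torch module, wrapped so its parameters are tracked by Keras.
        self.fc = TorchModuleWarpper(nn.Linear(784, 10))

    def call(self, x):
        return self.fc(x)

model = TinyModel()
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
x = np.random.rand(32, 784).astype("float32")
y = np.random.randint(0, 10, size=(32,))
model.fit(x, y, epochs=1)  # the wrapped nn.Linear's parameters should now update
```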
I couldn't find a function ...
It would be ... In fact you can just reuse ...
Hi @fchollet

```python
# Assumed imports (not part of the original snippet):
import torch
import torch.nn as nn
import keras_core as keras


class TorchModuleWarpper(keras.layers.Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module.to("cuda")
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        if not self.lazy:
            self.track_module_parameters()

    def parameters(self, recurse=True):
        return self.module.parameters(recurse=not self.lazy)

    def track_module_parameters(self):
        # Register each torch parameter as a tracked Keras variable.
        for param in self.module.parameters():
            variable = keras.Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape, *args, **kwargs):
        if not self.lazy:
            # Non-lazy modules can rely on the Layer build-by-run helpers.
            self._build_by_run_for_single_pos_arg(args)
            self._build_by_run_for_kwargs(kwargs)
        else:
            # Lazy modules need one eager forward pass to create their parameters.
            sample_input = torch.ones(*input_shape).to("cuda")
            _ = self.module(sample_input)
        self.track_module_parameters()

    def call(self, inputs, *args, **kwargs):
        if not self.built:
            self.build(inputs.shape[1:])
        return self.module.forward(inputs, *args, **kwargs)
```

Some observations with respect to the following:
It's probably not possible to build the parameters for lazy torch modules except by doing an eager forward pass.
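To make that concrete, here is a hypothetical sketch of how the lazy branch above would be exercised (illustrative shapes; assumes a CUDA device):

```python
import torch
import torch.nn as nn

# Hypothetical: wrapping a lazy module; nothing is tracked until build() runs.
wrapper = TorchModuleWarpper(nn.LazyLinear(10))
print(len(wrapper.weights))      # 0 -- parameters don't exist yet

sample = torch.ones(4, 64).to("cuda")
_ = wrapper(sample)              # triggers build(), i.e. one eager forward pass
print(len(wrapper.weights))      # 2 -- weight and bias are now tracked
```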
Would love to know your thoughts on this. PS: Apologies for the delayed response.
@soumik12345 thanks -- please open a PR and let's move the discussion to the PR. This should live in ...
I was attempting to replicate the `demo_mnist_convnet` example by replacing `keras_core.layers` with `torch.nn.Module` inside a `keras_core.Model`. Although training with `model.fit` doesn't encounter any error, it seems that the model is not learning anything.

I further attempted to override the `train_step` function according to the guide custom_train_step_in_torch, but that too did not fix the issue. I also tried to explicitly call `torch.nn.Module.zero_grad(self)` as recommended in the same guide, without any effect.

Colab to reproduce the issue
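For context, here is a minimal sketch of the pattern the issue describes (hypothetical and simplified to dense layers): a raw torch.nn.Module used directly inside a keras_core.Model. Its parameters are not tracked as Keras variables, so the optimizer presumably has nothing to update, which would explain the model not learning.

```python
import torch.nn as nn
import keras_core as keras

class ConvNetLikeModel(keras.Model):
    def __init__(self):
        super().__init__()
        # A raw torch module -- not wrapped in a Keras layer, so Keras does not
        # see its parameters as trainable variables.
        self.torch_block = nn.Sequential(
            nn.Linear(784, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def call(self, x):
        return self.torch_block(x)

model = ConvNetLikeModel()
# model.trainable_weights would be empty here, so model.fit() runs without
# errors but never updates the torch parameters.
```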