Possible bug: Torch backend backpropagation #604

Closed
soumik12345 opened this issue Jul 25, 2023 · 20 comments

@soumik12345
Contributor

soumik12345 commented Jul 25, 2023

I was attempting to replicate the demo_mnist_convnet example by replacing keras_core.layers with torch.nn.Module inside a keras_core.Model. Although training with model.fit runs without error, the model does not appear to learn anything.

I then tried overriding the train_step function according to the custom_train_step_in_torch guide, but that did not fix the issue either. I also tried explicitly calling torch.nn.Module.zero_grad(self) as recommended in the same guide, with no effect.

Colab to reproduce the issue
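
For context, here is a minimal sketch of the kind of setup described above (an assumed reconstruction, not the code from the linked Colab; the class name and shapes are illustrative):

import torch.nn as nn
import keras_core as keras  # assumes KERAS_BACKEND="torch"

# torch.nn modules assigned directly to a keras_core.Model: model.fit() runs,
# but Keras does not track the torch parameters, so the optimizer never updates them.
class TorchInsideKerasModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(28 * 28, 10)

    def call(self, x):
        return nn.functional.softmax(self.fc(self.flatten(x)), dim=-1)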

@fchollet
Member

> replacing keras_core.layers with torch.nn.Module inside a keras_core.Model

This is not intended to be supported at this time. Only Layer objects are tracked by Keras, not torch Modules.

@soumik12345
Contributor Author

> This is not intended to be supported at this time. Only Layer objects are tracked by Keras, not torch Modules.

Is it possible to create new Keras Core layers that abstract torch Modules?

@fchollet
Member

It's possible, but you'd need to wrap the module in a layer in a way that makes it track the underlying parameters(). We haven't implemented anything like this yet. But it's likely simple.

@soumik12345
Contributor Author

> It's possible, but you'd need to wrap the module in a layer in a way that makes it track the underlying parameters(). We haven't implemented anything like this yet. But it's likely simple.

This would be a valuable feature; it would make keras_core compatible with much, if not all, of the PyTorch ecosystem.

@fchollet
Member

Would you like to explore implementing a prototype of a TorchModuleWrapper? If it's simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

@soumik12345
Contributor Author

> Would you like to explore implementing a prototype of a TorchModuleWrapper? If it's simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

Yes! This is something I would love to explore.

@fchollet
Member

It's probably going to be something like this (warning: entirely untested):

from keras_core.layers import Layer
from keras_core import Variable  # assuming these keras_core import paths

class TorchModuleWrapper(Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module

    def parameters(self, recurse=True):
        # Expose the underlying torch parameters so torch optimizers can still see them.
        return self.module.parameters(recurse=recurse)

    def build(self, _):
        if not self.built:
            # Track each torch parameter as a Keras variable.
            for param in self.module.parameters():
                variable = Variable(initializer=param, trainable=param.requires_grad)
                self._track_variable(variable)
        self.built = True

    def call(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)
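
A hypothetical usage sketch (not from the thread) of how a wrapper like the one above is meant to be used, assuming it works as intended; the model, names, and shapes are illustrative:

import torch.nn as nn
import keras_core as keras  # assumes KERAS_BACKEND="torch"

# Wrap plain torch modules and use them like Keras layers in a subclassed model.
class WrappedNet(keras.Model):
    def __init__(self):
        super().__init__()
        self.fc1 = TorchModuleWrapper(nn.Linear(784, 64))
        self.fc2 = TorchModuleWrapper(nn.Linear(64, 10))

    def call(self, x):
        x = keras.ops.relu(self.fc1(x))
        return keras.ops.softmax(self.fc2(x))

model = WrappedNet()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(x_train, y_train)  # x_train / y_train assumed to be flattened MNIST arrays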

@dmus

dmus commented Jul 31, 2023

This would be a very welcome feature; it would let you use .fit(), .evaluate(), etc. on, for example, all the models in https://docs.monai.io/en/stable/networks.html

@soumik12345
Contributor Author

> It's probably going to be something like this (warning: entirely untested):

I attempted to make the snippet shared by @fchollet work; here's where I'm currently stuck: there seems to be some issue with inferring the shapes. Even explicitly overriding the compute_output_shape function for the TorchModuleWrapper doesn't seem to work. I'll keep investigating.

@fchollet
Member

fchollet commented Jul 31, 2023

@soumik12345 I have it working with a couple of simple changes:

  • The module wrapper can set self.built = True since torch module weights are already created (mind you, I guess this won't be the case with lazy modules, I have no idea how those work)
  • There was a device placement issue, as is common with torch. I placed the module on GPU to make it work.

https://colab.research.google.com/drive/1gG93Fb03Ef-77suS6b7PqgWtMOopNyyF?usp=sharing

Update -- actually there is a problem with variable tracking. It doesn't yet work.

Update -- fixed it by setting variable._value directly. It's training now.
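
Roughly, the fix amounts to something like this (a hedged sketch, not the actual notebook code): keras.Variable(initializer=param) creates a copy of the parameter, so the variable's underlying value is pointed back at the live torch parameter:

    def build(self, _):
        if not self.built:
            for param in self.module.parameters():
                variable = Variable(initializer=param, trainable=param.requires_grad)
                # Share the live torch parameter instead of a copied value.
                variable._value = param
                self._track_variable(variable)
        self.built = True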

@soumik12345
Contributor Author

@fchollet
Could you please make the notebook public? 😅

@soumik12345
Contributor Author

Also, should this be an officially supported feature on Keras Core? If not a feature, maybe this could be mentioned in a guide?

@fchollet
Member

> Also, should this be an officially supported feature on Keras Core? If not a feature, maybe this could be mentioned in a guide?

Sure, let's add it. First, please investigate the case where not all module parameters are created upon instantiation (lazy module). When that works, we can start a PR.

@soumik12345
Contributor Author

Hi @fchollet
I figured out how to make lazy torch modules train with Keras by slightly modifying the proposed TorchModuleWrapper. Would love to know your thoughts.

Linked Notebook: https://colab.research.google.com/drive/124Vn_d7_WvE2UaieLaG-pQavQuB7xHV3?usp=sharing

@fchollet
Member

fchollet commented Aug 8, 2023

Thanks for the analysis! Is there a way to merge both classes into a single one?

Instead of doing an eager _ = self.module(sample_input), we should leverage backend.compute_output_spec which will be more efficient.

@soumik12345
Contributor Author

> Thanks for the analysis! Is there a way to merge both classes into a single one?

Yes, let me attempt this.

@soumik12345
Contributor Author

soumik12345 commented Aug 9, 2023

> Is there a way to merge both classes into a single one?

Here's how the merged TorchModuleWrapper is looking:

import torch
import torch.nn as nn
import keras_core as keras  # assuming the keras_core-as-keras import used in the notebook

class TorchModuleWrapper(keras.layers.Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module.to("cuda")  # note: device is hard-coded in this draft
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        if not self.lazy:
            self.track_module_parameters()

    def parameters(self, recurse=True):
        return self.module.parameters(recurse=recurse)

    def track_module_parameters(self):
        for param in self.module.parameters():
            variable = keras.backend.Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape):
        # For lazy modules: one eager forward pass materializes the parameters.
        sample_input = torch.ones(*input_shape).to("cuda")
        _ = self.module(sample_input)
        self.track_module_parameters()

    def call(self, inputs, *args, **kwargs):
        if not self.built:
            self.build(inputs.shape[1:])
        return self.module.forward(inputs, *args, **kwargs)

Here's the updated notebook: https://colab.research.google.com/drive/1vy11aILVgrPtxKHAiohtKM6LlmpVj2EC?usp=sharing

> Instead of doing an eager _ = self.module(sample_input), we should leverage backend.compute_output_spec which will be more efficient.

I couldn't find a function backend.compute_output_spec; do you mean that we should override keras_core.layers.Layer.compute_output_spec?

@fchollet
Member

fchollet commented Aug 9, 2023

> I couldn't find a function backend.compute_output_spec; do you mean that we should override keras_core.layers.Layer.compute_output_spec?

It would be backend.compute_output_spec(self.__call__, *args, **kwargs) where *args, **kwargs would be KerasTensors (and other arguments).

In fact you can just reuse Layer._build_by_run_for_single_pos_arg() and Layer._build_by_run_for_kwargs (which we probably need to expose via a single unified method).

@soumik12345
Contributor Author

Hi @fchollet
Here's the TorchModuleWrapper modified as per your feedback (with some generous help from @ariG23498):

import torch
import torch.nn as nn
import keras_core as keras  # assuming the keras_core-as-keras import used in the notebook

class TorchModuleWrapper(keras.layers.Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module.to("cuda")  # note: device is still hard-coded in this draft
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        if not self.lazy:
            self.track_module_parameters()

    def parameters(self, recurse=True):
        return self.module.parameters(recurse=recurse)

    def track_module_parameters(self):
        for param in self.module.parameters():
            variable = keras.Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape, *args, **kwargs):
        if not self.lazy:
            # Non-lazy modules: build symbolically via the Layer helpers.
            self._build_by_run_for_single_pos_arg(args)
            self._build_by_run_for_kwargs(kwargs)
        else:
            # Lazy modules: one eager forward pass materializes the parameters.
            sample_input = torch.ones(*input_shape).to("cuda")
            _ = self.module(sample_input)
        self.track_module_parameters()

    def call(self, inputs, *args, **kwargs):
        if not self.built:
            self.build(inputs.shape[1:])
        return self.module.forward(inputs, *args, **kwargs)

Some observations wrt the following:

> Instead of doing an eager _ = self.module(sample_input), we should leverage backend.compute_output_spec which will be more efficient.

It's probably not possible to build the parameters of lazy torch modules except by doing an eager _ = self.module(sample_input), since they don't initialize their parameters until there's a forward pass. This is what the official docs mention:

> Modules that lazily initialize parameters, or “lazy modules”, derive the shapes of their parameters from the first input(s) to their forward method. Until that first forward they contain torch.nn.UninitializedParameters that should not be accessed or used, and afterward they contain regular torch.nn.Parameters.
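
Concretely (an illustration not from the thread, using torch.nn.LazyLinear):

import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=8)
# Before the first forward pass the weight is an UninitializedParameter with no shape.
print(isinstance(lazy.weight, nn.parameter.UninitializedParameter))  # True

_ = lazy(torch.ones(2, 16))  # the first forward pass materializes the parameters
print(lazy.weight.shape)  # torch.Size([8, 16])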

Would love to know your thoughts on this.

PS: Apologies for the delayed response.

@fchollet
Member

@soumik12345 thanks -- please open a PR and let's move the discussion to the PR. This should live in backend/torch/torch_module_wrapper.py. There is some light refactoring we will need.
