From dbc52fe6ac74db76b6b1f9a95d042127e29752b0 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Wed, 23 Aug 2023 07:00:28 +0000 Subject: [PATCH 01/11] add: TorchModuleWarpper --- keras_core/backend/torch/__init__.py | 1 + .../backend/torch/torch_module_wrapper.py | 160 ++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 keras_core/backend/torch/torch_module_wrapper.py diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index a26f5f647..e1f91ee59 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -36,3 +36,4 @@ from keras_core.backend.torch.rnn import gru from keras_core.backend.torch.rnn import lstm from keras_core.backend.torch.rnn import rnn +from keras_core.backend.torch.torch_module_wrapper import TorchModuleWarpper diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py new file mode 100644 index 000000000..32ad7d7fa --- /dev/null +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn + +from keras_core.layers import Layer +from keras_core.backend import Variable +from keras_core.api_export import keras_core_export + + +@keras_core_export(["keras_core.backend.torch.TorchModuleWarpper"]) +class TorchModuleWarpper(Layer): + """Torch module wrapper layer. + + `TorchModuleWarpper` is an abstraction that can be wrapped around a + `torch.nn.Module` to make its parameters trackable as a + `keras_core.layers.Layer`. It works with both vanilla and lazy PyTorch + modules. + + Args: + module: torch.nn.Module, A vanilla or lazy PyTorch neural network module. + name: The name of the layer (string). + + References: + - [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) + - [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) + + Examples: + + Here's an example of how the `TorchModuleWarpper` can be used with vanilla PyTorch + modules. + + ```python + import torch.nn as nn + import torch.nn.functional as F + + import keras_core + from keras_core.backend.torch import TorchModuleWarpper + + + class Classifier(keras_core.Model): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Wrap all `torch.nn.Module`s with `TorchModuleWarpper` + self.conv1 = TorchModuleWarpper( + nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)) + ) + self.conv2 = TorchModuleWarpper( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)) + ) + self.pool = TorchModuleWarpper( + nn.MaxPool2d(kernel_size=(2, 2)) + ) + self.flatten = TorchModuleWarpper(nn.Flatten()) + self.dropout = TorchModuleWarpper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWarpper(nn.Linear(1600, 10)) + + def call(self, inputs): + x = F.relu(self.conv1(inputs)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.dropout(x) + x = self.fc(x) + return F.softmax(x, dim=1) + + + model = Classifier() + model.build((1, 28, 28)) + print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"] + ) + model.fit(train_loader, epochs=5) + ``` + + Here's an example of how the `TorchModuleWarpper` can be used with PyTorch + Lazy modules. 
+ + ```python + import torch.nn as nn + import torch.nn.functional as F + + import keras_core + from keras_core.backend.torch import TorchModuleWarpper + + + class LazyClassifier(keras.Model): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # You can wrap all `torch.nn.Module`s with `TorchModuleWarpper` + # irrespective of whether they are lazy or not. + self.conv1 = TorchModuleWarpper( + nn.LazyConv2d(out_channels=32, kernel_size=(3, 3)) + ) + self.conv2 = TorchModuleWarpper( + nn.LazyConv2d(out_channels=64, kernel_size=(3, 3)) + ) + self.pool = TorchModuleWarpper(nn.MaxPool2d(kernel_size=(2, 2))) + self.flatten = TorchModuleWarpper(nn.Flatten()) + self.dropout = TorchModuleWarpper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWarpper(nn.LazyLinear(10)) + + def call(self, inputs): + x = F.relu(self.conv1(inputs)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.dropout(x) + x = self.fc(x) + return F.softmax(x, dim=1) + + + model = Classifier() + model.build((1, 28, 28)) + print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"] + ) + model.fit(train_loader, epochs=5) + ``` + + """ + + def __init__(self, module, name=None): + super().__init__(name=name) + self.module = module.to("cuda") + self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin) + if not self.lazy: + self.track_module_parameters() + + def parameters(self, recurse=True): + return self.module.parameters(recurse=recurse) + + def track_module_parameters(self): + for param in self.module.parameters(): + variable = Variable( + initializer=param, trainable=param.requires_grad + ) + variable._value = param + self._track_variable(variable) + self.built = True + + def build(self, input_shape): + sample_input = torch.ones(*input_shape).to("cuda") + _ = self.module(sample_input) + self.track_module_parameters() + + def call(self, inputs, **kwargs): + if not self.built: + self.build(inputs.shape[1:]) + return self.module.forward(inputs, **kwargs) From a5317950c483f9020eb2bbdaafa1616b0c6f51ac Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Wed, 23 Aug 2023 07:10:37 +0000 Subject: [PATCH 02/11] chore: make ci happy --- keras_core/backend/torch/torch_module_wrapper.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index 32ad7d7fa..a90852221 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -16,17 +16,18 @@ class TorchModuleWarpper(Layer): modules. Args: - module: torch.nn.Module, A vanilla or lazy PyTorch neural network module. + module: torch.nn.Module, A vanilla or lazy PyTorch neural network + module. name: The name of the layer (string). 
References: - - [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) - - [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) + - [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) # noqa: E501 + - [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) # noqa: E501 Examples: - Here's an example of how the `TorchModuleWarpper` can be used with vanilla PyTorch - modules. + Here's an example of how the `TorchModuleWarpper` can be used with vanilla + PyTorch modules. ```python import torch.nn as nn From 93c1b5e35b154413ba6a7bb444f9e4fed27f6785 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Sat, 26 Aug 2023 13:30:58 +0000 Subject: [PATCH 03/11] fix: typo in TorchModuleWrapper --- keras_core/backend/torch/__init__.py | 2 +- .../backend/torch/torch_module_wrapper.py | 42 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index e1f91ee59..604b6217b 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -36,4 +36,4 @@ from keras_core.backend.torch.rnn import gru from keras_core.backend.torch.rnn import lstm from keras_core.backend.torch.rnn import rnn -from keras_core.backend.torch.torch_module_wrapper import TorchModuleWarpper +from keras_core.backend.torch.torch_module_wrapper import TorchModuleWrapper diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index a90852221..8206f5f45 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -6,11 +6,11 @@ from keras_core.api_export import keras_core_export -@keras_core_export(["keras_core.backend.torch.TorchModuleWarpper"]) -class TorchModuleWarpper(Layer): +@keras_core_export(["keras_core.backend.torch.TorchModuleWrapper"]) +class TorchModuleWrapper(Layer): """Torch module wrapper layer. - `TorchModuleWarpper` is an abstraction that can be wrapped around a + `TorchModuleWrapper` is an abstraction that can be wrapped around a `torch.nn.Module` to make its parameters trackable as a `keras_core.layers.Layer`. It works with both vanilla and lazy PyTorch modules. @@ -26,7 +26,7 @@ class TorchModuleWarpper(Layer): Examples: - Here's an example of how the `TorchModuleWarpper` can be used with vanilla + Here's an example of how the `TorchModuleWrapper` can be used with vanilla PyTorch modules. 
```python @@ -34,26 +34,26 @@ class TorchModuleWarpper(Layer): import torch.nn.functional as F import keras_core - from keras_core.backend.torch import TorchModuleWarpper + from keras_core.backend.torch import TorchModuleWrapper class Classifier(keras_core.Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Wrap all `torch.nn.Module`s with `TorchModuleWarpper` - self.conv1 = TorchModuleWarpper( + # Wrap all `torch.nn.Module`s with `TorchModuleWrapper` + self.conv1 = TorchModuleWrapper( nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)) ) - self.conv2 = TorchModuleWarpper( + self.conv2 = TorchModuleWrapper( nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)) ) - self.pool = TorchModuleWarpper( + self.pool = TorchModuleWrapper( nn.MaxPool2d(kernel_size=(2, 2)) ) - self.flatten = TorchModuleWarpper(nn.Flatten()) - self.dropout = TorchModuleWarpper(nn.Dropout(p=0.5)) - self.fc = TorchModuleWarpper(nn.Linear(1600, 10)) + self.flatten = TorchModuleWrapper(nn.Flatten()) + self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWrapper(nn.Linear(1600, 10)) def call(self, inputs): x = F.relu(self.conv1(inputs)) @@ -78,7 +78,7 @@ def call(self, inputs): model.fit(train_loader, epochs=5) ``` - Here's an example of how the `TorchModuleWarpper` can be used with PyTorch + Here's an example of how the `TorchModuleWrapper` can be used with PyTorch Lazy modules. ```python @@ -86,25 +86,25 @@ def call(self, inputs): import torch.nn.functional as F import keras_core - from keras_core.backend.torch import TorchModuleWarpper + from keras_core.backend.torch import TorchModuleWrapper class LazyClassifier(keras.Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # You can wrap all `torch.nn.Module`s with `TorchModuleWarpper` + # You can wrap all `torch.nn.Module`s with `TorchModuleWrapper` # irrespective of whether they are lazy or not. 
- self.conv1 = TorchModuleWarpper( + self.conv1 = TorchModuleWrapper( nn.LazyConv2d(out_channels=32, kernel_size=(3, 3)) ) - self.conv2 = TorchModuleWarpper( + self.conv2 = TorchModuleWrapper( nn.LazyConv2d(out_channels=64, kernel_size=(3, 3)) ) - self.pool = TorchModuleWarpper(nn.MaxPool2d(kernel_size=(2, 2))) - self.flatten = TorchModuleWarpper(nn.Flatten()) - self.dropout = TorchModuleWarpper(nn.Dropout(p=0.5)) - self.fc = TorchModuleWarpper(nn.LazyLinear(10)) + self.pool = TorchModuleWrapper(nn.MaxPool2d(kernel_size=(2, 2))) + self.flatten = TorchModuleWrapper(nn.Flatten()) + self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWrapper(nn.LazyLinear(10)) def call(self, inputs): x = F.relu(self.conv1(inputs)) From 868c36ff3fb53b3e0497e716f2bc74a40611c422 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Sat, 26 Aug 2023 13:35:49 +0000 Subject: [PATCH 04/11] update: Variable import --- keras_core/backend/torch/torch_module_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index 8206f5f45..2229c331c 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -2,7 +2,6 @@ import torch.nn as nn from keras_core.layers import Layer -from keras_core.backend import Variable from keras_core.api_export import keras_core_export @@ -142,6 +141,8 @@ def parameters(self, recurse=True): return self.module.parameters(recurse=recurse) def track_module_parameters(self): + from keras_core.backend import Variable + for param in self.module.parameters(): variable = Variable( initializer=param, trainable=param.requires_grad From 116dabab21ca4e86e62db1beeef5397bdb5db86b Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Sat, 26 Aug 2023 13:50:08 +0000 Subject: [PATCH 05/11] update: TorchModuleWrapper --- .../backend/torch/torch_module_wrapper.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index 2229c331c..d89be8ddd 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -138,10 +138,10 @@ def __init__(self, module, name=None): self.track_module_parameters() def parameters(self, recurse=True): - return self.module.parameters(recurse=recurse) + return self.module.parameters(recurse=not self.lazy) def track_module_parameters(self): - from keras_core.backend import Variable + from keras_core.backend.torch import Variable for param in self.module.parameters(): variable = Variable( @@ -151,12 +151,16 @@ def track_module_parameters(self): self._track_variable(variable) self.built = True - def build(self, input_shape): - sample_input = torch.ones(*input_shape).to("cuda") - _ = self.module(sample_input) + def build(self, input_shape, *args, **kwargs): + if not self.lazy: + self._build_by_run_for_single_pos_arg(args) + self._build_by_run_for_kwargs(kwargs) + else: + sample_input = torch.ones(*input_shape).to("cuda") + _ = self.module(sample_input) self.track_module_parameters() - def call(self, inputs, **kwargs): + def call(self, inputs, *args, **kwargs): if not self.built: self.build(inputs.shape[1:]) - return self.module.forward(inputs, **kwargs) + return self.module.forward(inputs, *args, **kwargs) From 
bef832046abec5c2dda2974c3a21eebbc6eadb84 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Sat, 26 Aug 2023 14:07:42 +0000 Subject: [PATCH 06/11] update: manage device placement --- keras_core/backend/torch/torch_module_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index d89be8ddd..8ec85fdec 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -132,7 +132,8 @@ def call(self, inputs): def __init__(self, module, name=None): super().__init__(name=name) - self.module = module.to("cuda") + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.module = module.to(self.device) self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin) if not self.lazy: self.track_module_parameters() From 5e43c6062406681319f0d972a7738d4da90f6625 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 29 Aug 2023 19:19:38 +0000 Subject: [PATCH 07/11] refactor: TorchModuleWrapper to keras_core.utils.torch_utils --- keras_core/backend/torch/__init__.py | 1 - .../backend/torch/torch_module_wrapper.py | 167 ------------------ keras_core/utils/torch_utils.py | 167 ++++++++++++++++++ 3 files changed, 167 insertions(+), 168 deletions(-) create mode 100644 keras_core/utils/torch_utils.py diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index 604b6217b..a26f5f647 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -36,4 +36,3 @@ from keras_core.backend.torch.rnn import gru from keras_core.backend.torch.rnn import lstm from keras_core.backend.torch.rnn import rnn -from keras_core.backend.torch.torch_module_wrapper import TorchModuleWrapper diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py index 8ec85fdec..e69de29bb 100644 --- a/keras_core/backend/torch/torch_module_wrapper.py +++ b/keras_core/backend/torch/torch_module_wrapper.py @@ -1,167 +0,0 @@ -import torch -import torch.nn as nn - -from keras_core.layers import Layer -from keras_core.api_export import keras_core_export - - -@keras_core_export(["keras_core.backend.torch.TorchModuleWrapper"]) -class TorchModuleWrapper(Layer): - """Torch module wrapper layer. - - `TorchModuleWrapper` is an abstraction that can be wrapped around a - `torch.nn.Module` to make its parameters trackable as a - `keras_core.layers.Layer`. It works with both vanilla and lazy PyTorch - modules. - - Args: - module: torch.nn.Module, A vanilla or lazy PyTorch neural network - module. - name: The name of the layer (string). - - References: - - [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) # noqa: E501 - - [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) # noqa: E501 - - Examples: - - Here's an example of how the `TorchModuleWrapper` can be used with vanilla - PyTorch modules. 
- - ```python - import torch.nn as nn - import torch.nn.functional as F - - import keras_core - from keras_core.backend.torch import TorchModuleWrapper - - - class Classifier(keras_core.Model): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Wrap all `torch.nn.Module`s with `TorchModuleWrapper` - self.conv1 = TorchModuleWrapper( - nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)) - ) - self.conv2 = TorchModuleWrapper( - nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)) - ) - self.pool = TorchModuleWrapper( - nn.MaxPool2d(kernel_size=(2, 2)) - ) - self.flatten = TorchModuleWrapper(nn.Flatten()) - self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) - self.fc = TorchModuleWrapper(nn.Linear(1600, 10)) - - def call(self, inputs): - x = F.relu(self.conv1(inputs)) - x = self.pool(x) - x = F.relu(self.conv2(x)) - x = self.pool(x) - x = self.flatten(x) - x = self.dropout(x) - x = self.fc(x) - return F.softmax(x, dim=1) - - - model = Classifier() - model.build((1, 28, 28)) - print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) - - model.compile( - loss="sparse_categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"] - ) - model.fit(train_loader, epochs=5) - ``` - - Here's an example of how the `TorchModuleWrapper` can be used with PyTorch - Lazy modules. - - ```python - import torch.nn as nn - import torch.nn.functional as F - - import keras_core - from keras_core.backend.torch import TorchModuleWrapper - - - class LazyClassifier(keras.Model): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # You can wrap all `torch.nn.Module`s with `TorchModuleWrapper` - # irrespective of whether they are lazy or not. - self.conv1 = TorchModuleWrapper( - nn.LazyConv2d(out_channels=32, kernel_size=(3, 3)) - ) - self.conv2 = TorchModuleWrapper( - nn.LazyConv2d(out_channels=64, kernel_size=(3, 3)) - ) - self.pool = TorchModuleWrapper(nn.MaxPool2d(kernel_size=(2, 2))) - self.flatten = TorchModuleWrapper(nn.Flatten()) - self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) - self.fc = TorchModuleWrapper(nn.LazyLinear(10)) - - def call(self, inputs): - x = F.relu(self.conv1(inputs)) - x = self.pool(x) - x = F.relu(self.conv2(x)) - x = self.pool(x) - x = self.flatten(x) - x = self.dropout(x) - x = self.fc(x) - return F.softmax(x, dim=1) - - - model = Classifier() - model.build((1, 28, 28)) - print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) - - model.compile( - loss="sparse_categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"] - ) - model.fit(train_loader, epochs=5) - ``` - - """ - - def __init__(self, module, name=None): - super().__init__(name=name) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.module = module.to(self.device) - self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin) - if not self.lazy: - self.track_module_parameters() - - def parameters(self, recurse=True): - return self.module.parameters(recurse=not self.lazy) - - def track_module_parameters(self): - from keras_core.backend.torch import Variable - - for param in self.module.parameters(): - variable = Variable( - initializer=param, trainable=param.requires_grad - ) - variable._value = param - self._track_variable(variable) - self.built = True - - def build(self, input_shape, *args, **kwargs): - if not self.lazy: - self._build_by_run_for_single_pos_arg(args) - self._build_by_run_for_kwargs(kwargs) - else: - sample_input = torch.ones(*input_shape).to("cuda") - 
_ = self.module(sample_input) - self.track_module_parameters() - - def call(self, inputs, *args, **kwargs): - if not self.built: - self.build(inputs.shape[1:]) - return self.module.forward(inputs, *args, **kwargs) diff --git a/keras_core/utils/torch_utils.py b/keras_core/utils/torch_utils.py new file mode 100644 index 000000000..2dfcd08ff --- /dev/null +++ b/keras_core/utils/torch_utils.py @@ -0,0 +1,167 @@ +import torch +import torch.nn as nn + +from keras_core.layers import Layer +from keras_core.api_export import keras_core_export + + +@keras_core_export(["keras_core.utils.torch_utils.TorchModuleWrapper"]) +class TorchModuleWrapper(Layer): + """Torch module wrapper layer. + + `TorchModuleWrapper` is an abstraction that can be wrapped around a + `torch.nn.Module` to make its parameters trackable as a + `keras_core.layers.Layer`. It works with both vanilla and lazy PyTorch + modules. + + Args: + module: torch.nn.Module, A vanilla or lazy PyTorch neural network + module. + name: The name of the layer (string). + + References: + - [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) # noqa: E501 + - [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) # noqa: E501 + + Examples: + + Here's an example of how the `TorchModuleWrapper` can be used with vanilla + PyTorch modules. + + ```python + import torch.nn as nn + import torch.nn.functional as F + + import keras_core + from keras_core.backend.torch import TorchModuleWrapper + + + class Classifier(keras_core.Model): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Wrap all `torch.nn.Module`s with `TorchModuleWrapper` + self.conv1 = TorchModuleWrapper( + nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)) + ) + self.conv2 = TorchModuleWrapper( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)) + ) + self.pool = TorchModuleWrapper( + nn.MaxPool2d(kernel_size=(2, 2)) + ) + self.flatten = TorchModuleWrapper(nn.Flatten()) + self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWrapper(nn.Linear(1600, 10)) + + def call(self, inputs): + x = F.relu(self.conv1(inputs)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.dropout(x) + x = self.fc(x) + return F.softmax(x, dim=1) + + + model = Classifier() + model.build((1, 28, 28)) + print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"] + ) + model.fit(train_loader, epochs=5) + ``` + + Here's an example of how the `TorchModuleWrapper` can be used with PyTorch + Lazy modules. + + ```python + import torch.nn as nn + import torch.nn.functional as F + + import keras_core + from keras_core.backend.torch import TorchModuleWrapper + + + class LazyClassifier(keras.Model): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # You can wrap all `torch.nn.Module`s with `TorchModuleWrapper` + # irrespective of whether they are lazy or not. 
+ self.conv1 = TorchModuleWrapper( + nn.LazyConv2d(out_channels=32, kernel_size=(3, 3)) + ) + self.conv2 = TorchModuleWrapper( + nn.LazyConv2d(out_channels=64, kernel_size=(3, 3)) + ) + self.pool = TorchModuleWrapper(nn.MaxPool2d(kernel_size=(2, 2))) + self.flatten = TorchModuleWrapper(nn.Flatten()) + self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5)) + self.fc = TorchModuleWrapper(nn.LazyLinear(10)) + + def call(self, inputs): + x = F.relu(self.conv1(inputs)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.dropout(x) + x = self.fc(x) + return F.softmax(x, dim=1) + + + model = Classifier() + model.build((1, 28, 28)) + print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"] + ) + model.fit(train_loader, epochs=5) + ``` + + """ + + def __init__(self, module, name=None): + super().__init__(name=name) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.module = module.to(self.device) + self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin) + if not self.lazy: + self.track_module_parameters() + + def parameters(self, recurse=True): + return self.module.parameters(recurse=not self.lazy) + + def track_module_parameters(self): + from keras_core.backend.torch import Variable + + for param in self.module.parameters(): + variable = Variable( + initializer=param, trainable=param.requires_grad + ) + variable._value = param + self._track_variable(variable) + self.built = True + + def build(self, input_shape, *args, **kwargs): + if not self.lazy: + self._build_by_run_for_single_pos_arg(args) + self._build_by_run_for_kwargs(kwargs) + else: + sample_input = torch.ones(*input_shape).to("cuda") + _ = self.module(sample_input) + self.track_module_parameters() + + def call(self, inputs, *args, **kwargs): + if not self.built: + self.build(inputs.shape[1:]) + return self.module.forward(inputs, *args, **kwargs) From d943ce7571e77c15429b89975651b24b2bcc357f Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 29 Aug 2023 19:39:47 +0000 Subject: [PATCH 08/11] update: using keras_core.backend.torch.core.compute_output_spec for flopless build --- keras_core/utils/torch_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_core/utils/torch_utils.py b/keras_core/utils/torch_utils.py index 2dfcd08ff..ba6e18ced 100644 --- a/keras_core/utils/torch_utils.py +++ b/keras_core/utils/torch_utils.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn +import keras_core from keras_core.layers import Layer from keras_core.api_export import keras_core_export @@ -157,8 +158,9 @@ def build(self, input_shape, *args, **kwargs): self._build_by_run_for_single_pos_arg(args) self._build_by_run_for_kwargs(kwargs) else: - sample_input = torch.ones(*input_shape).to("cuda") - _ = self.module(sample_input) + # sample_input = torch.ones(*input_shape).to("cuda") + # _ = self.module(sample_input) + _ = keras_core.backend.torch.core.compute_output_spec(self.module) self.track_module_parameters() def call(self, inputs, *args, **kwargs): From 095492338505192d81651e683e072db81e9f332c Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 29 Aug 2023 19:48:53 +0000 Subject: [PATCH 09/11] update: build for lazy modules --- keras_core/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/utils/torch_utils.py 
b/keras_core/utils/torch_utils.py index ba6e18ced..d5d5f9026 100644 --- a/keras_core/utils/torch_utils.py +++ b/keras_core/utils/torch_utils.py @@ -160,7 +160,7 @@ def build(self, input_shape, *args, **kwargs): else: # sample_input = torch.ones(*input_shape).to("cuda") # _ = self.module(sample_input) - _ = keras_core.backend.torch.core.compute_output_spec(self.module) + _ = keras_core.backend.torch.core.compute_output_spec(self.__call__) self.track_module_parameters() def call(self, inputs, *args, **kwargs): From 76a7b1ab8887ae0202b4eddf790b7fbda2a3c979 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 29 Aug 2023 19:55:26 +0000 Subject: [PATCH 10/11] update: remove self.build() in self.call() --- keras_core/utils/torch_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/keras_core/utils/torch_utils.py b/keras_core/utils/torch_utils.py index d5d5f9026..1a3d0d6bc 100644 --- a/keras_core/utils/torch_utils.py +++ b/keras_core/utils/torch_utils.py @@ -153,7 +153,7 @@ def track_module_parameters(self): self._track_variable(variable) self.built = True - def build(self, input_shape, *args, **kwargs): + def build(self, *args, **kwargs): if not self.lazy: self._build_by_run_for_single_pos_arg(args) self._build_by_run_for_kwargs(kwargs) @@ -163,7 +163,5 @@ def build(self, input_shape, *args, **kwargs): _ = keras_core.backend.torch.core.compute_output_spec(self.__call__) self.track_module_parameters() - def call(self, inputs, *args, **kwargs): - if not self.built: - self.build(inputs.shape[1:]) - return self.module.forward(inputs, *args, **kwargs) + def call(self, inputs, **kwargs): + return self.module.forward(inputs, **kwargs) From bee9fb7b41b0e8ec7a43b759e5c0b74af44c42b5 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 29 Aug 2023 19:58:10 +0000 Subject: [PATCH 11/11] update: removed unnecessary file --- keras_core/backend/torch/torch_module_wrapper.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 keras_core/backend/torch/torch_module_wrapper.py diff --git a/keras_core/backend/torch/torch_module_wrapper.py b/keras_core/backend/torch/torch_module_wrapper.py deleted file mode 100644 index e69de29bb..000000000
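
Note: after PATCH 11 the wrapper lives at `keras_core/utils/torch_utils.py` and is exported as `keras_core.utils.torch_utils.TorchModuleWrapper`. The sketch below is an illustrative usage example, not part of the patches: the `TinyRegressor` model, the synthetic data, and the training settings are assumptions made here for demonstration; only the import path and the `TorchModuleWrapper(module, name=None)` constructor come from the diffs above. It also assumes the torch backend is selected (e.g. `KERAS_BACKEND=torch`) and that the wrapped modules and the input tensors end up on the same device.

```python
# Minimal sketch of using the wrapper as it stands after PATCH 11.
# Assumes the torch backend is active; TinyRegressor and the random
# data are hypothetical, introduced only for illustration.
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import keras_core
from keras_core.utils.torch_utils import TorchModuleWrapper


class TinyRegressor(keras_core.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Non-lazy torch modules: their parameters are registered as
        # Keras variables as soon as the wrapper is constructed.
        self.hidden = TorchModuleWrapper(nn.Linear(8, 16))
        self.out = TorchModuleWrapper(nn.Linear(16, 1))

    def call(self, inputs):
        x = F.relu(self.hidden(inputs))
        return self.out(x)


model = TinyRegressor()
model.compile(optimizer="adam", loss="mse")

# Synthetic data, just to exercise fit().
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

# Two wrapped Linear modules -> two (weight, bias) pairs tracked.
print(len(model.trainable_weights))  # expected: 4
```

For non-lazy modules the parameters are tracked at construction time, while lazy modules (per PATCH 08-10) are materialized in `build()` via `keras_core.backend.torch.core.compute_output_spec` before their parameters are tracked.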