Add TorchModuleWrapper (#779)
* add: TorchModuleWrapper

* chore: make ci happy

* fix: typo in TorchModuleWrapper

* update: Variable import

* update: TorchModuleWrapper

* update: manage device placement

* refactor: TorchModuleWrapper to keras_core.utils.torch_utils

* update: using keras_core.backend.torch.core.compute_output_spec for flopless build

* update: build for lazy modules

* update: remove self.build() in self.call()

* update: removed unnecessary file
soumik12345 authored Aug 30, 2023
1 parent e63e94c commit 5735fad
Showing 1 changed file with 167 additions and 0 deletions.
167 changes: 167 additions & 0 deletions keras_core/utils/torch_utils.py
@@ -0,0 +1,167 @@
import torch
import torch.nn as nn

import keras_core
from keras_core.layers import Layer
from keras_core.api_export import keras_core_export


@keras_core_export(["keras_core.utils.torch_utils.TorchModuleWrapper"])
class TorchModuleWrapper(Layer):
"""Torch module wrapper layer.
`TorchModuleWrapper` is an abstraction that can be wrapped around a
`torch.nn.Module` to make its parameters trackable as a
`keras_core.layers.Layer`. It works with both vanilla and lazy PyTorch
modules.
Args:
module: torch.nn.Module, A vanilla or lazy PyTorch neural network
module.
name: The name of the layer (string).
References:
- [PyTorch docs for `torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) # noqa: E501
- [PyTorch docs for `LazyModuleMixin`](https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html) # noqa: E501
Examples:
Here's an example of how the `TorchModuleWrapper` can be used with vanilla
PyTorch modules.
```python
import torch.nn as nn
import torch.nn.functional as F
import keras_core
from keras_core.backend.torch import TorchModuleWrapper
class Classifier(keras_core.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Wrap all `torch.nn.Module`s with `TorchModuleWrapper`
self.conv1 = TorchModuleWrapper(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
)
self.pool = TorchModuleWrapper(
nn.MaxPool2d(kernel_size=(2, 2))
)
self.flatten = TorchModuleWrapper(nn.Flatten())
self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5))
self.fc = TorchModuleWrapper(nn.Linear(1600, 10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)
```
Here's an example of how the `TorchModuleWrapper` can be used with PyTorch
Lazy modules.
```python
import torch.nn as nn
import torch.nn.functional as F
import keras_core
from keras_core.backend.torch import TorchModuleWrapper
class LazyClassifier(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# You can wrap all `torch.nn.Module`s with `TorchModuleWrapper`
# irrespective of whether they are lazy or not.
self.conv1 = TorchModuleWrapper(
nn.LazyConv2d(out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.LazyConv2d(out_channels=64, kernel_size=(3, 3))
)
self.pool = TorchModuleWrapper(nn.MaxPool2d(kernel_size=(2, 2)))
self.flatten = TorchModuleWrapper(nn.Flatten())
self.dropout = TorchModuleWrapper(nn.Dropout(p=0.5))
self.fc = TorchModuleWrapper(nn.LazyLinear(10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)
```
"""

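    # The wrapped module is moved to the GPU when one is available. For
    # non-lazy modules, parameters are tracked right away; lazy modules are
    # tracked later, once `build()` has materialized their parameters.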
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.module = module.to(self.device)
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        if not self.lazy:
            self.track_module_parameters()

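    # Mirror `torch.nn.Module.parameters()` so a torch optimizer can collect
    # the wrapped parameters; recursion is disabled for lazy modules.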
    def parameters(self, recurse=True):
        return self.module.parameters(recurse=not self.lazy)

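    # Wrap every parameter of the wrapped module in a Keras Variable so it is
    # tracked by this layer, then mark the layer as built.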
    def track_module_parameters(self):
        from keras_core.backend.torch import Variable

        for param in self.module.parameters():
            variable = Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

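    # Non-lazy modules are built with the standard Keras build-by-run path.
    # Lazy modules first need a symbolic forward pass via
    # `compute_output_spec`, which materializes their parameters without
    # running real computation, so that they can then be tracked.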
    def build(self, *args, **kwargs):
        if not self.lazy:
            self._build_by_run_for_single_pos_arg(args)
            self._build_by_run_for_kwargs(kwargs)
        else:
            # sample_input = torch.ones(*input_shape).to("cuda")
            # _ = self.module(sample_input)
            _ = keras_core.backend.torch.core.compute_output_spec(
                self.__call__
            )
        self.track_module_parameters()

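    # Calling the layer delegates to the wrapped module's `forward()`.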
    def call(self, inputs, **kwargs):
        return self.module.forward(inputs, **kwargs)

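To sanity-check the new utility outside the docstring examples, a minimal smoke test along the following lines should work. This is a sketch rather than part of the commit: it assumes the torch backend is selected via the `KERAS_BACKEND` environment variable, and the layer size and input shape here are arbitrary.

```python
import os

# The torch backend must be selected before keras_core is imported.
os.environ["KERAS_BACKEND"] = "torch"

import torch.nn as nn

import keras_core
from keras_core.utils.torch_utils import TorchModuleWrapper


class TinyModel(keras_core.Model):
    def __init__(self):
        super().__init__()
        # Wrap a vanilla torch module; TorchModuleWrapper moves it to the GPU
        # when one is available and tracks its parameters as Keras variables.
        self.fc = TorchModuleWrapper(nn.Linear(4, 2))

    def call(self, inputs):
        return self.fc(inputs)


model = TinyModel()
out = model(keras_core.ops.ones((1, 4)))
print(out.shape)           # -> shape (1, 2)
print(len(model.weights))  # expected: 2 (the Linear kernel and bias)
```

Because `nn.Linear` is not a lazy module, its kernel and bias are wrapped as Keras variables at construction time, so `model.weights` should already report them after the first call.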