Make Keras layers autowrap torch modules
fchollet committed Aug 31, 2023
1 parent 7a096ef commit 14b3755
Showing 4 changed files with 53 additions and 7 deletions.
14 changes: 14 additions & 0 deletions keras_core/backend/torch/layer.py
@@ -24,3 +24,17 @@ def parameters(self, recurse=True):
 
     def forward(self, *args, **kwargs):
        return Operation.__call__(self, *args, **kwargs)
+
+    def _setattr_hook(self, name, value):
+        from keras_core.layers import Layer
+
+        if (
+            isinstance(value, torch.nn.Module)
+            and not isinstance(value, Layer)
+            and not name == "torch_params"
+        ):
+            from keras_core.utils.torch_utils import TorchModuleWrapper
+
+            if not isinstance(self, TorchModuleWrapper):
+                value = TorchModuleWrapper(value)
+        return name, value
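
The hook above wraps any attribute that is a raw torch.nn.Module, while leaving Keras layers (which already subclass torch.nn.Module under the torch backend) and the reserved "torch_params" attribute untouched. Below is a minimal standalone sketch of that decision logic — not keras_core's actual code; WrappedModule is a hypothetical stand-in for TorchModuleWrapper — runnable with just torch installed:

import torch


class WrappedModule:
    """Hypothetical stand-in for keras_core's TorchModuleWrapper."""

    def __init__(self, module):
        self.module = module


def setattr_hook(name, value, is_keras_layer=lambda v: False):
    # Wrap raw torch modules; leave Keras layers and the reserved
    # "torch_params" attribute untouched.
    if (
        isinstance(value, torch.nn.Module)
        and not is_keras_layer(value)
        and name != "torch_params"
    ):
        value = WrappedModule(value)
    return name, value


# A raw torch module gets wrapped; a plain value passes through unchanged.
_, wrapped = setattr_hook("fc", torch.nn.Linear(2, 4))
_, untouched = setattr_hook("units", 32)
assert isinstance(wrapped, WrappedModule)
assert untouched == 32
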
5 changes: 1 addition & 4 deletions keras_core/layers/layer.py
@@ -361,10 +361,6 @@ def build(self, input_shape):
             )
         self.built = True
 
-    def _post_build(self):
-        """Can be overridden for per backend post build actions."""
-        pass
-
     def _lock_state(self):
         """Prevent further state updates, called automatically in `build()`."""
         if not self._tracker.locked:
@@ -1229,6 +1225,7 @@ def __str__(self):
 
     def __setattr__(self, name, value):
         # Track Variables, Layers, Metrics, SeedGenerators.
+        name, value = self._setattr_hook(name, value)
         if hasattr(self, "_tracker"):
             value = self._tracker.track(value)
         elif name != "_tracker":
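
The key ordering detail in __setattr__ is that the backend hook runs before tracking, so the tracker only ever sees the (possibly wrapped) value. A pure-Python sketch of that flow, using hypothetical classes rather than keras_core's own:

class TrackedBase:
    """Sketch of the Layer flow: hook first, then track, then set."""

    def _setattr_hook(self, name, value):
        return name, value  # no-op default, as on Operation

    def __setattr__(self, name, value):
        name, value = self._setattr_hook(name, value)  # backend hook first
        if hasattr(self, "_tracker"):
            self._tracker.append((name, value))  # stand-in for tracker.track
        super().__setattr__(name, value)


class WrappingChild(TrackedBase):
    """Toy override that 'wraps' strings, mimicking module autowrapping."""

    def __init__(self):
        # Install the tracker via object.__setattr__ so the tracker itself
        # is never tracked, mirroring the `elif name != "_tracker"` guard.
        object.__setattr__(self, "_tracker", [])

    def _setattr_hook(self, name, value):
        if isinstance(value, str):
            value = f"wrapped({value})"
        return name, value


obj = WrappingChild()
obj.title = "fc"
assert obj.title == "wrapped(fc)"
assert obj._tracker == [("title", "wrapped(fc)")]
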
9 changes: 9 additions & 0 deletions keras_core/ops/operation.py
@@ -256,3 +256,12 @@ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
             return values[0]
         else:
             return values
+
+    # Hooks for backend layer classes
+    def _post_build(self):
+        """Can be overridden for per backend post build actions."""
+        pass
+
+    def _setattr_hook(self, name, value):
+        """Can be overridden for per backend attribute-setting actions."""
+        return name, value
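
Both methods follow the template-method pattern: Operation defines no-op defaults at fixed call points, and a backend base class overrides only what it needs. A hypothetical sketch of the _post_build side of this (keras_core's real torch backend differs in the details):

class OperationSketch:
    """Hypothetical stand-in for Operation with its two backend hooks."""

    def _post_build(self):
        pass  # no-op default

    def _setattr_hook(self, name, value):
        return name, value  # no-op default

    def build(self):
        # ... create variables here ...
        self.built = True
        self._post_build()  # backend gets the last word after build


class TorchOperationSketch(OperationSketch):
    def _post_build(self):
        # A torch backend could, for example, collect built variables into
        # torch parameters at this point (hypothetical behavior).
        self.torch_params = list(getattr(self, "weights", []))


op = TorchOperationSketch()
op.build()
assert op.built and op.torch_params == []
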
32 changes: 29 additions & 3 deletions keras_core/utils/torch_utils_test.py
@@ -3,6 +3,7 @@
 import torch
 
 from keras_core import backend
+from keras_core import layers
 from keras_core import models
 from keras_core import testing
 from keras_core.utils.torch_utils import TorchModuleWrapper
@@ -17,14 +18,39 @@ def call(self, x):
         return self.fc(x)
 
 
-@pytest.mark.skipif(
-    backend.backend() != "torch", reason="Requires torch backend"
-)
+class ClassifierWithNoSpecialCasing(models.Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fc = torch.nn.Linear(2, 4)
+        self.fc2 = layers.Dense(2)
+
+    def call(self, x):
+        return self.fc(self.fc2(x))
+
+
 class TorchUtilsTest(testing.TestCase):
+    @pytest.mark.skipif(
+        backend.backend() != "torch", reason="Requires torch backend"
+    )
     def test_basic_usage(self):
         model = Classifier()
         self.assertEqual(len(model.layers), 1)
         self.assertEqual(len(model.trainable_weights), 2)
         model(np.random.random((3, 2)))
         model.compile(optimizer="sgd", loss="mse")
         model.fit(np.random.random((3, 2)), np.random.random((3, 4)))
+
+    @pytest.mark.skipif(
+        backend.backend() != "torch", reason="Requires torch backend"
+    )
+    def test_module_autowrapping(self):
+        model = ClassifierWithNoSpecialCasing()
+        self.assertTrue(isinstance(model.fc, TorchModuleWrapper))
+        self.assertFalse(isinstance(model.fc2, TorchModuleWrapper))
+        self.assertEqual(len(model.fc.trainable_weights), 2)
+        model(np.random.random((3, 2)))
+        self.assertEqual(len(model._layers), 2)
+        self.assertEqual(len(model.fc2.trainable_weights), 2)
+        self.assertEqual(len(model.trainable_weights), 4)
+        model.compile(optimizer="sgd", loss="mse")
+        model.fit(np.random.random((3, 2)), np.random.random((3, 4)))
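
The net effect for users, per the tests above: under the torch backend, assigning a raw torch.nn.Module to a model attribute now works without an explicit TorchModuleWrapper. A usage sketch (runs only with KERAS_BACKEND=torch):

import numpy as np
import torch

from keras_core import models


class MixedModel(models.Model):
    def __init__(self):
        super().__init__()
        # Before this commit: TorchModuleWrapper(torch.nn.Linear(2, 4))
        self.fc = torch.nn.Linear(2, 4)  # now autowrapped on assignment

    def call(self, x):
        return self.fc(x)


model = MixedModel()
model(np.random.random((3, 2)))  # builds; fc's parameters are tracked
model.compile(optimizer="sgd", loss="mse")
model.fit(np.random.random((3, 2)), np.random.random((3, 4)))
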
