Skip to content

Commit

Permalink
add _modules and _parameters property
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Mar 22, 2022
1 parent 99d6d6b commit 31fbf4a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.1.49'
__version__ = '1.3.1.50'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down Expand Up @@ -957,6 +957,14 @@ def callback_leave(parents, k, v, n):
self.dfs([], "", callback, callback_leave)
return ms

@property
def _modules(self):
return { k:v for k,v in self.__dict__.items() if isinstance(v, Module) }

@property
def _parameters(self):
return { k:v for k,v in self.__dict__.items() if isinstance(v, Var) }

def requires_grad_(self, requires_grad=True):
self._requires_grad = requires_grad
self._place_hooker()
Expand Down
13 changes: 13 additions & 0 deletions python/jittor/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,18 @@ def test_module(self):
a.y = 2
assert a.y == 2

def test_modules(self):
a = jt.Module()
a.x = jt.Module()
a.y = jt.Module()
a.a = jt.array([1,2,3])
a.b = jt.array([1,2,3])
assert a._modules.keys() == ["x", "y"]
assert a._modules['x'] is a.x
assert a._modules['y'] is a.y
assert a._parameters.keys() == ['a', 'b']
assert a._parameters['a'] is a.a
assert a._parameters['b'] is a.b

if __name__ == "__main__":
unittest.main()

0 comments on commit 31fbf4a

Please sign in to comment.