Skip to content

Commit

Permalink
[Relax] Support nested ModuleList in nn.Module (#16971)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored May 7, 2024
1 parent 28d32b5 commit 819b002
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,16 +607,19 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:

def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
"""Find attributes that satisfy the condition recursively"""
if isinstance(root, ModuleList):
for i, subitem in enumerate(root):
yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield)
return
for name, item in root.__dict__.items():
if condition_yield(item):
yield prefix + name, item
elif isinstance(item, ModuleList):
for i, subitem in enumerate(item):
yield from _attribute_finder(
subitem,
prefix + name + f".{i}.",
condition_yield,
)
yield from _attribute_finder(
item,
prefix + name + ".",
condition_yield,
)
elif isinstance(item, Module):
yield from _attribute_finder(
item,
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relax/test_frontend_nn_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,5 +700,20 @@ def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dty
assert_structural_equal(tvm_mod["forward"], forward)


def test_module_list():
class Module(nn.Module):
def __init__(self):
self.layers = nn.ModuleList(
[nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)]) for _ in range(1)]
)

def forward(self, x: nn.Tensor):
return self.layers(x)

mod = Module()
named_params = dict(mod.named_parameters())
assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 819b002

Please sign in to comment.