diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 4953c1c81701..46e016a242ea 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -607,16 +607,19 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]: def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]): """Find attributes that satisfy the condition recursively""" + if isinstance(root, ModuleList): + for i, subitem in enumerate(root): + yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield) + return for name, item in root.__dict__.items(): if condition_yield(item): yield prefix + name, item elif isinstance(item, ModuleList): - for i, subitem in enumerate(item): - yield from _attribute_finder( - subitem, - prefix + name + f".{i}.", - condition_yield, - ) + yield from _attribute_finder( + item, + prefix + name + ".", + condition_yield, + ) elif isinstance(item, Module): yield from _attribute_finder( item, diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 5ddc10505591..23250f28aa9f 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -700,5 +700,20 @@ def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dty assert_structural_equal(tvm_mod["forward"], forward) +def test_module_list(): + class Module(nn.Module): + def __init__(self): + self.layers = nn.ModuleList( + [nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)]) for _ in range(1)] + ) + + def forward(self, x: nn.Tensor): + return self.layers(x) + + mod = Module() + named_params = dict(mod.named_parameters()) + assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys())) + + if __name__ == "__main__": tvm.testing.main()