[BUG] Restoring multiagent nets #1960

Closed
matteobettini opened this issue Feb 23, 2024 · 1 comment · Fixed by pytorch/tensordict#689
Labels
bug Something isn't working

@matteobettini
Contributor

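Since #1921, multi-agent networks can no longer be restored from their state dict: loading the output of state_dict() back into the same module raises a KeyError. Minimal reproduction: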

from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules.models.multiagent import MultiAgentMLP

if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )

    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
    dict = policy_module.state_dict()
    policy_module.load_state_dict(dict)

Running this script fails with:

Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 25, in <module>
    policy_module.load_state_dict(dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 994, in _load_from_state_dict
    TensorDict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2455, in get
    return self._get_tuple(key, default=default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1647, in _get_tuple
    first = self._get_str(key[0], default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1643, in _get_str
    return self._default_get(first_key, default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2433, in _default_get
    raise KeyError(
KeyError: 'key "module.params" not found in TensorDict with keys [\'module\']'
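
The error message suggests that the dotted prefix "module.params" was looked up as a single key in a nested TensorDict whose only top-level key is "module". The pure-Python sketch below reproduces that kind of mismatch with plain dicts. The nested layout, the helper names get_path and prefix_to_path, and the dummy values are all hypothetical; this is not TensorDict's code and not the fix merged in pytorch/tensordict#689.

```python
# Hypothetical nested parameter container, shaped like the one described in
# the error message: the only top-level key is "module".
nested = {"module": {"params": {"0.weight": "w0", "0.bias": "b0"}}}


def get_path(tree, path):
    """Walk a nested dict along a tuple of keys, raising KeyError on a miss."""
    node = tree
    for key in path:
        if key not in node:
            raise KeyError(
                f'key "{key}" not found in container with keys {sorted(node)}'
            )
        node = node[key]
    return node


def prefix_to_path(prefix):
    """Split a dotted state-dict prefix such as "module.params." into a tuple path."""
    return tuple(part for part in prefix.split(".") if part)


prefix = "module.params"

# Treating the whole dotted prefix as one key fails, mirroring the KeyError above.
failed = None
try:
    get_path(nested, (prefix,))
except KeyError as err:
    failed = str(err)

# Splitting the prefix into its components reaches the nested entry.
found = get_path(nested, prefix_to_path(prefix))
print(failed)
print(found)
```

The point of the sketch is only that a flat, dot-joined name and a nested key path address the same entry in different ways, so a loader has to convert between the two consistently.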
@vmoens
Contributor

vmoens commented Feb 23, 2024

On it. By the way, I would consider not using state dict :)
