[BUG] Multiagent nets problems with SAC #1957

Closed

matteobettini opened this issue Feb 23, 2024 · 8 comments · Fixed by pytorch/tensordict#688

Labels: bug Something isn't working
@matteobettini (Contributor) commented Feb 23, 2024

I am experiencing a new series of bugs in the BenchMARL library after #1921.

I'll try to list them here as I unravel them.

A first one can be observed when running python examples/multiagent/sac.py model.shared_parameters=False.

This seems related to the use of ensembles in SAC:

Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/sac.py", line 193, in train
    loss_module = SACLoss(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/sac.py", line 327, in __init__
    self.convert_to_functional(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/common.py", line 292, in convert_to_functional
    with params.apply(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1152, in new_func
    out = func(_self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 595, in to_module
    return self._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 334, in _to_module
    module.update(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1114, in new_func
    return func(self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 417, in update
    TensorDictBase.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2613, in update
    target.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2620, in update
    self._set_tuple(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1515, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1485, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 4378, in _validate_value
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([3]) and value.shape=torch.Size([2, 3, 256, 54]).

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
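
Judging from the shapes in the error, the per-agent parameter TensorDict (batch_size=[n_agents]) seems to clash with the q-value ensemble stacking: the stacked leaves gain a leading dimension of 2 (the two q-networks) that no longer matches the agent-sized batch dimension. A minimal sketch of that kind of mismatch (not the exact convert_to_functional code path), assuming the shapes from the traceback (3 agents, a 256x54 weight, 2 q-networks):

import torch
from tensordict import TensorDict

n_agents = 3
# Per-agent parameters, batched over the agent dimension
# (as produced with share_params=False).
params = TensorDict(
    {"weight": torch.randn(n_agents, 256, 54)}, batch_size=[n_agents]
)

# Stacking two copies (the q-value ensemble) adds a leading dim of 2.
stacked = torch.stack([params, params.clone()], dim=0)  # batch_size == [2, 3]

# Writing a stacked leaf back into a TensorDict that still has
# batch_size=[3] raises the same "batch dimension mismatch" RuntimeError.
params.set("weight", stacked.get("weight"))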
@matteobettini added the "bug" label on Feb 23, 2024
@matteobettini changed the title from "[BUG] Several issues with new multiagent nets from #1921" to "[BUG] Issues with multiagent nets from #1921" on Feb 23, 2024
@matteobettini (Contributor, Author) commented Feb 23, 2024

The second one concerns reloading a state_dict:

from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules.models.multiagent import MultiAgentMLP

if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )

    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
    dict = policy_module.state_dict()
    policy_module.load_state_dict(dict)
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 25, in <module>
    policy_module.load_state_dict(dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 994, in _load_from_state_dict
    TensorDict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2455, in get
    return self._get_tuple(key, default=default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1647, in _get_tuple
    first = self._get_str(key[0], default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1643, in _get_str
    return self._default_get(first_key, default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2433, in _default_get
    raise KeyError(
KeyError: 'key "module.params" not found in TensorDict with keys [\'module\']'

EDIT: moved to #1960

@vmoens (Contributor) commented Feb 23, 2024

For the first one, can you provide a minimal reproducible example too? Thanks!

@matteobettini (Contributor, Author)

For the first one, can you provide a minimal reproducible example too? Thanks!

from tensordict.nn import TensorDictModule, NormalParamExtractor
from torch import nn

from torchrl.envs import VmasEnv
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import SACLoss, SoftUpdate, ValueEstimators

if __name__ == "__main__":
    env = VmasEnv(
        scenario="navigation",
        num_envs=4,
        continuous_actions=True,
        max_steps=100,
        device="cpu",
        seed=0,
    )

    actor_net = nn.Sequential(
        MultiAgentMLP(
            n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
            n_agent_outputs=2 * env.action_spec.shape[-1],
            n_agents=env.n_agents,
            centralised=False,
            share_params=False,
            device="cpu",
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        ),
        NormalParamExtractor(),
    )
    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "loc"), ("agents", "scale")],
    )

    policy = ProbabilisticActor(
        module=policy_module,
        spec=env.unbatched_action_spec,
        in_keys=[("agents", "loc"), ("agents", "scale")],
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": env.unbatched_action_spec[("agents", "action")].space.low,
            "max": env.unbatched_action_spec[("agents", "action")].space.high,
        },
        return_log_prob=True,
    )

    # Critic
    module = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1]
        + env.action_spec.shape[-1],  # Q critic takes action and value
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    value_module = ValueOperator(
        module=module,
        in_keys=[("agents", "observation"), env.action_key],
        out_keys=[("agents", "state_action_value")],
    )

    loss_module = SACLoss(
        actor_network=policy,
        qvalue_network=value_module,
        delay_qvalue=True,
        action_spec=env.unbatched_action_spec,
    )
    loss_module.set_keys(
        state_action_value=("agents", "state_action_value"),
        action=env.action_key,
        reward=env.reward_key,
        done=("agents", "done"),
        terminated=("agents", "terminated"),
    )
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 71, in <module>
    loss_module = SACLoss(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/sac.py", line 327, in __init__
    self.convert_to_functional(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/common.py", line 292, in convert_to_functional
    with params.apply(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1152, in new_func
    out = func(_self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 595, in to_module
    return self._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 334, in _to_module
    module.update(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1114, in new_func
    return func(self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 417, in update
    TensorDictBase.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2613, in update
    target.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2620, in update
    self._set_tuple(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1515, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1485, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 4378, in _validate_value
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([4]) and value.shape=torch.Size([2, 4, 256]).

@vmoens (Contributor) commented Feb 23, 2024

Something with just the network and no env? Those modules don't require anything else in the torchrl stack, no?

@matteobettini (Contributor, Author) commented Feb 23, 2024

Something with just the network and no env? Those modules don't require anything else in the torchrl stack, no?

The env is only used to set up the specs; the crash happens in the loss, which is just fed the modules. The env in the script is never called.

I'll try to rewrite it more minimally.

@matteobettini (Contributor, Author)

from tensordict.nn import TensorDictModule
from torch import nn


from torchrl.modules import ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import SACLoss

if __name__ == "__main__":
    obs_size = 4
    action_size = 6
    n_agents = 3

    actor_net = MultiAgentMLP(
        n_agent_inputs=obs_size,
        n_agent_outputs=action_size,
        n_agents=n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    policy = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action"), ("agents", "scale")],
    )

    # Critic
    module = MultiAgentMLP(
        n_agent_inputs=obs_size + action_size,  # Q critic takes action and value
        n_agent_outputs=1,
        n_agents=n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    value_module = ValueOperator(
        module=module,
        in_keys=[("agents", "observation"), ("agents", "action")],
        out_keys=[("agents", "state_action_value")],
    )

    loss_module = SACLoss(
        actor_network=policy,
        qvalue_network=value_module,
        delay_qvalue=True,
    )

@vmoens (Contributor) commented Feb 23, 2024

Since those are two separate issues, can you make, well... two separate issues? 😀

@matteobettini changed the title from "[BUG] Issues with multiagent nets from #1921" to "[BUG] Multiagent nets problems with SAC" on Feb 23, 2024
@matteobettini (Contributor, Author)

I opened #1960 for the state_dict restoring problems.

I am keeping this one for the SAC-related problem.
