[BUG] Multiagent nets problems with SAC #1957
The second one concerns reloading:

```python
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules.models.multiagent import MultiAgentMLP

if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
    dict = policy_module.state_dict()
    policy_module.load_state_dict(dict)
```

```
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 25, in <module>
    policy_module.load_state_dict(dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2126, in load
    load(child, child_state_dict, child_prefix)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 994, in _load_from_state_dict
    TensorDict(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2455, in get
    return self._get_tuple(key, default=default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1647, in _get_tuple
    first = self._get_str(key[0], default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1643, in _get_str
    return self._default_get(first_key, default)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2433, in _default_get
    raise KeyError(
KeyError: 'key "module.params" not found in TensorDict with keys [\'module\']'
```

EDIT: moved to #1960
For the first one, can you provide a minimal reproducible example too? Thanks!
```python
from tensordict.nn import TensorDictModule, NormalParamExtractor
from torch import nn

from torchrl.envs import VmasEnv
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import SACLoss, SoftUpdate, ValueEstimators

if __name__ == "__main__":
    env = VmasEnv(
        scenario="navigation",
        num_envs=4,
        continuous_actions=True,
        max_steps=100,
        device="cpu",
        seed=0,
    )
    actor_net = nn.Sequential(
        MultiAgentMLP(
            n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
            n_agent_outputs=2 * env.action_spec.shape[-1],
            n_agents=env.n_agents,
            centralised=False,
            share_params=False,
            device="cpu",
            depth=2,
            num_cells=256,
            activation_class=nn.Tanh,
        ),
        NormalParamExtractor(),
    )
    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "loc"), ("agents", "scale")],
    )
    policy = ProbabilisticActor(
        module=policy_module,
        spec=env.unbatched_action_spec,
        in_keys=[("agents", "loc"), ("agents", "scale")],
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": env.unbatched_action_spec[("agents", "action")].space.low,
            "max": env.unbatched_action_spec[("agents", "action")].space.high,
        },
        return_log_prob=True,
    )
    # Critic
    module = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1]
        + env.action_spec.shape[-1],  # Q critic takes action and value
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    value_module = ValueOperator(
        module=module,
        in_keys=[("agents", "observation"), env.action_key],
        out_keys=[("agents", "state_action_value")],
    )
    loss_module = SACLoss(
        actor_network=policy,
        qvalue_network=value_module,
        delay_qvalue=True,
        action_spec=env.unbatched_action_spec,
    )
    loss_module.set_keys(
        state_action_value=("agents", "state_action_value"),
        action=env.action_key,
        reward=env.reward_key,
        done=("agents", "done"),
        terminated=("agents", "terminated"),
    )
```

```
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/examples/multiagent/prova.py", line 71, in <module>
    loss_module = SACLoss(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/sac.py", line 327, in __init__
    self.convert_to_functional(
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/objectives/common.py", line 292, in convert_to_functional
    with params.apply(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1152, in new_func
    out = func(_self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 595, in to_module
    return self._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 411, in _to_module
    local_out = value._to_module(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 334, in _to_module
    module.update(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 1114, in new_func
    return func(self, *args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/params.py", line 417, in update
    TensorDictBase.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2613, in update
    target.update(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 2620, in update
    self._set_tuple(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1515, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/_td.py", line 1485, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/base.py", line 4378, in _validate_value
    raise RuntimeError(
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([4]) and value.shape=torch.Size([2, 4, 256]).
```
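For reference, the shape check that fires at the bottom of this traceback can be reproduced in isolation. This is a minimal sketch of the same TensorDict validation (illustration only, not the SACLoss/`convert_to_functional` code path; the shapes are copied from the error message):

```python
# Minimal sketch of the TensorDict shape validation raised above (illustration
# only): a TensorDict with batch_size=[4] rejects any value whose leading
# dimensions do not start with 4, e.g. a [2, 4, 256] parameter.
import torch
from tensordict import TensorDict

td = TensorDict({}, batch_size=[4])
try:
    td["weight"] = torch.zeros(2, 4, 256)
except RuntimeError as err:
    print(err)  # batch dimension mismatch, got self.batch_size=torch.Size([4]) ...
```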
Could you share something with just the network and no env? Those modules don't require anything else in the torchrl stack, no?
The env is only used to set up the specs; the crash happens in the loss, which is just fed modules. The env in the script is never called. I'll try to rewrite it without the env:
```python
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules import ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import SACLoss

if __name__ == "__main__":
    obs_size = 4
    action_size = 6
    n_agents = 3
    actor_net = MultiAgentMLP(
        n_agent_inputs=obs_size,
        n_agent_outputs=action_size,
        n_agents=n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    policy = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action"), ("agents", "scale")],
    )
    # Critic
    module = MultiAgentMLP(
        n_agent_inputs=obs_size + action_size,  # Q critic takes action and value
        n_agent_outputs=1,
        n_agents=n_agents,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )
    value_module = ValueOperator(
        module=module,
        in_keys=[("agents", "observation"), ("agents", "action")],
        out_keys=[("agents", "state_action_value")],
    )
    loss_module = SACLoss(
        actor_network=policy,
        qvalue_network=value_module,
        delay_qvalue=True,
    )
```
Since those are two separate issues, can you make, well... two separate issues? 😀
I opened #1960 for the restoring problems. I am keeping this issue for the SAC-related issue. |
I am experiencing a new series of bugs in the BenchMARL library after #1921. I'll try to list them here as I unravel them.

A first one can be observed when running:

```
python examples/multiagent/sac.py model.shared_parameters=False
```

This seems related to the use of ensembles in SAC.
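For context on what "ensembles" means here, a hedged sketch (my reading with assumed shapes, not the actual SACLoss internals): the SAC loss keeps stacked copies of the critic parameters, one per q-network, and vmaps the critic over that leading dimension; with `share_params=False` the MultiAgentMLP parameters already carry an `n_agents` leading dimension, which appears to collide with the ensemble dimension in the RuntimeError reported above.

```python
# Hedged sketch of a stacked q-network ensemble evaluated with vmap
# (assumed shapes for illustration; not the SACLoss implementation):
import torch

num_qvalue_nets, n_agents, in_features, hidden = 2, 4, 10, 256

# one weight matrix per q-network, stacked along a leading ensemble dim
weights = torch.randn(num_qvalue_nets, hidden, in_features)

def q_layer(weight, x):
    # a single linear layer standing in for the critic, for illustration only
    return x @ weight.t()

x = torch.randn(32, in_features)
# vmap over the ensemble dimension: one forward pass per q-network
out = torch.vmap(q_layer, in_dims=(0, None))(weights, x)
print(out.shape)  # torch.Size([2, 32, 256])
```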