From d7fecbbb9bff254e557a668bacc558a975827719 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 10:43:30 +0100 Subject: [PATCH 01/28] Initial commit --- examples/smac/env.py | 177 ++++++++++++++++++++++++++++++++++++++++ examples/smac/policy.py | 16 ++++ 2 files changed, 193 insertions(+) create mode 100644 examples/smac/env.py create mode 100644 examples/smac/policy.py diff --git a/examples/smac/env.py b/examples/smac/env.py new file mode 100644 index 00000000000..142d2c21643 --- /dev/null +++ b/examples/smac/env.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from smac.env import StarCraft2Env + +from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, \ + UnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec +from torchrl.data.tensor_specs import _default_dtype_and_device, DiscreteBox, \ + DEVICE_TYPING +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs.common import GymLikeEnv + + +@dataclass(repr=False) +class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): + def __init__( + self, + n: int, + d: int, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.d = d + shape = torch.Size((d, space.n,)) + super(OneHotDiscreteTensorSpec, self).__init__(shape, space, device, + dtype, "discrete") + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.nn.functional.gumbel_softmax( + torch.rand(*shape, self.d, self.space.n, device=self.device), + hard=True, + dim=-1, + ).to(torch.long) + + +@dataclass(repr=False) +class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec): + def __init__( + self, + mask: torch.Tensor, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + self.mask = mask + *_, d, n = mask.shape + + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.d = d + shape = torch.Size((d, space.n,)) + super(OneHotDiscreteTensorSpec, self).__init__(shape, space, device, + dtype, "discrete") + + def to(self, dest): + out = super().to(dest) + out.mask = self.mask.to(dest) + return out + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + mask = self.mask.to(torch.float) + return torch.nn.functional.gumbel_softmax( + mask.log(), + hard=True, + dim=-1, + ).to(torch.long) + + def is_in(self, value): + return ((self.mask - value) >= 0).all() + + +class SCEnv(GymLikeEnv): + available_envs = ["8m"] + # TODO: add to parent class + supplementary_keys = ["available_actions"] + + @property + def observation_spec(self): + info = self._env.get_env_info() + dim = (info['n_agents'], info['obs_shape']) + return NdUnboundedContinuousTensorSpec(dim) + + @property + def action_spec(self): + info = self._env.get_env_info() + return CustomNdOneHotDiscreteTensorSpec( + torch.tensor(self._env.get_avail_actions())) + + @property + def reward_spec(self): + return UnboundedContinuousTensorSpec() + + def _build_env( + self, map_name: str, taskname=None, **kwargs + ) -> None: + if taskname: + raise RuntimeError + + env = StarCraft2Env(map_name=map_name) + self._env = env + return env + + def _output_transform(self, step_result): + reward, done, *other = step_result + obs = self._env.get_obs() + available_actions = self._env.get_avail_actions() + return obs, reward, done, available_actions, *other + + def _reset( + self, tensor_dict: Optional[_TensorDict] = None, **kwargs + ) -> _TensorDict: + obs = self._env.get_obs() + + tensor_dict_out = TensorDict( + source=self._read_obs(obs), batch_size=self.batch_size + ) + self._is_done = torch.zeros(1, dtype=torch.bool) + tensor_dict_out.set("done", self._is_done) + available_actions = self._env.get_avail_actions() + tensor_dict_out.set("available_actions", available_actions) + return tensor_dict_out + + def _init_env(self, seed=None): + self._env.reset() + if seed is not None: + self.set_seed(seed) + + # TODO: check that actions match avail + def _action_transform(self, action): + action_np = self.action_spec.to_numpy(action) + return action_np + + # TODO: move to GymLike + def _step(self, tensor_dict: _TensorDict) -> _TensorDict: + action = tensor_dict.get("action") + action_np = self._action_transform(action) + + reward = 0.0 + for _ in range(self.wrapper_frame_skip): + obs, _reward, done, *info = self._output_transform( + self._env.step(action_np) + ) + if _reward is None: + _reward = 0.0 + reward += _reward + if done: + break + + obs_dict = self._read_obs(obs) + + if reward is None: + reward = np.nan + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) + done = self._to_tensor(done, dtype=torch.bool) + self._is_done = done + self._current_tensordict = obs_dict + + tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size) + for key, value in obs_dict.items(): + tensor_dict_out.set(f"next_{key}", value) + tensor_dict_out.set("reward", reward) + tensor_dict_out.set("done", done) + for k, value in zip(self.supplementary_keys, info): + tensor_dict_out.set(k, value) + + return tensor_dict_out diff --git a/examples/smac/policy.py b/examples/smac/policy.py new file mode 100644 index 00000000000..038d089ae37 --- /dev/null +++ b/examples/smac/policy.py @@ -0,0 +1,16 @@ +from torch import nn, Tensor + +class MaskedLogitPolicy(nn.Module): + def __init__(self, policy_module): + super().__init__() + self.policy_module = policy_module + + def forward(self, *inputs): + *inputs, mask = inputs + outputs = self.policy_module(*inputs) + if isinstance(outputs, Tensor): + outputs = (outputs,) + # first output is logits + outputs[0].masked_fill_(mask.expand_as(outputs[0], -float("inf"))) + return tuple(outputs) + From 9c476efaf17b7ce5480510ad5bb1cc33359a06c6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 11:27:14 +0100 Subject: [PATCH 02/28] cuda 10.2 --- .circleci/unittest/linux/scripts/environment.yml | 1 + examples/smac/smac_test.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 examples/smac/smac_test.py diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index 0af0ab2ca28..b64ed32f221 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -3,6 +3,7 @@ channels: - defaults dependencies: - pip + - ninja - cmake >= 3.18 - pip: - hypothesis diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py new file mode 100644 index 00000000000..e6c79db05fe --- /dev/null +++ b/examples/smac/smac_test.py @@ -0,0 +1,16 @@ +from env import SCEnv +from torchrl.envs import TransformedEnv, ObservationNorm + +if __name__ == "__main__": + # create an env + env = SCEnv("8m") + + # reset + td = env.reset() + print("tensordict after reset: ") + print(td) + + # apply a sequence of transforms + env = TransformedEnv(env, ObservationNorm(0, 1, standard_normal=True)) + + # From 0708191e43b9103c59933b6be34c3b8d818a0d12 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 11:49:13 +0100 Subject: [PATCH 03/28] script --- examples/smac/smac_test.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index e6c79db05fe..6717cc82247 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,5 +1,8 @@ from env import SCEnv +from examples.smac.policy import MaskedLogitPolicy from torchrl.envs import TransformedEnv, ObservationNorm +from torchrl.modules import ProbabilisticTDModule, OneHotCategorical +from torch import nn if __name__ == "__main__": # create an env @@ -13,4 +16,20 @@ # apply a sequence of transforms env = TransformedEnv(env, ObservationNorm(0, 1, standard_normal=True)) - # + # Get policy + policy = nn.LazyLinear(env.action_spec.shape[-1]) + policy_wrap = MaskedLogitPolicy(policy) + policy_td_module = ProbabilisticTDModule( + policy_wrap, + in_keys=["observation", "available_actions"], + out_keys=["action"], + distribution_class=OneHotCategorical, + ) + + # Test the policy + policy_td_module(td) + print(td) + + # check that an ation can be performed in the env with this + env.step(td) + print(td) From 259cecd2c48e68475792a7834733cd3877fbf29e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 11:52:02 +0100 Subject: [PATCH 04/28] script --- examples/smac/smac_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 6717cc82247..7164b001968 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -20,7 +20,8 @@ policy = nn.LazyLinear(env.action_spec.shape[-1]) policy_wrap = MaskedLogitPolicy(policy) policy_td_module = ProbabilisticTDModule( - policy_wrap, + module=policy_wrap, + spec=None, in_keys=["observation", "available_actions"], out_keys=["action"], distribution_class=OneHotCategorical, From 1aa333cd6df34729b2ffd6c4f02c5fe085fedd89 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 11:53:49 +0100 Subject: [PATCH 05/28] script --- examples/smac/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/smac/policy.py b/examples/smac/policy.py index 038d089ae37..24b81cd70ae 100644 --- a/examples/smac/policy.py +++ b/examples/smac/policy.py @@ -11,6 +11,6 @@ def forward(self, *inputs): if isinstance(outputs, Tensor): outputs = (outputs,) # first output is logits - outputs[0].masked_fill_(mask.expand_as(outputs[0], -float("inf"))) + outputs[0].masked_fill_(mask.expand_as(outputs[0]), -float("inf")) return tuple(outputs) From 59853e036815dcbcca3cd00d4d2512f38b4d41f7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:23:25 +0100 Subject: [PATCH 06/28] script --- examples/smac/smac_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 7164b001968..9a050bb2efd 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -25,6 +25,7 @@ in_keys=["observation", "available_actions"], out_keys=["action"], distribution_class=OneHotCategorical, + save_dist_params=True, ) # Test the policy From 0f30a673afbf027488659d75fbe1190b63829a9e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:30:11 +0100 Subject: [PATCH 07/28] script --- examples/smac/smac_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 9a050bb2efd..19de7172ae7 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -31,6 +31,10 @@ # Test the policy policy_td_module(td) print(td) + print('param: ', td.get("action_dist_param_0")) + print('action: ', td.get("action")) + print('mask: ', td.get("available_actions")) + print('mask from env: ', env.env._env.get_avail_actions()) # check that an ation can be performed in the env with this env.step(td) From 8b3fe9b385d8cad869dadeb4f7621394b8b8af8b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:32:43 +0100 Subject: [PATCH 08/28] script --- examples/smac/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/smac/policy.py b/examples/smac/policy.py index 24b81cd70ae..bc452d2d265 100644 --- a/examples/smac/policy.py +++ b/examples/smac/policy.py @@ -11,6 +11,6 @@ def forward(self, *inputs): if isinstance(outputs, Tensor): outputs = (outputs,) # first output is logits - outputs[0].masked_fill_(mask.expand_as(outputs[0]), -float("inf")) + outputs[0].masked_fill_(~mask.expand_as(outputs[0]), -float("inf")) return tuple(outputs) From cd60690cbb9081c91a7586270c91bd273911611c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:33:38 +0100 Subject: [PATCH 09/28] script --- examples/smac/policy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/smac/policy.py b/examples/smac/policy.py index bc452d2d265..8eebb5fa697 100644 --- a/examples/smac/policy.py +++ b/examples/smac/policy.py @@ -1,3 +1,4 @@ +import torch from torch import nn, Tensor class MaskedLogitPolicy(nn.Module): @@ -11,6 +12,8 @@ def forward(self, *inputs): if isinstance(outputs, Tensor): outputs = (outputs,) # first output is logits - outputs[0].masked_fill_(~mask.expand_as(outputs[0]), -float("inf")) + outputs[0].masked_fill_( + ~mask.to(torch.bool).expand_as(outputs[0]), + -float("inf")) return tuple(outputs) From 0fc501467e380c154ee1b3cae76cd0bfb83fc1ac Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:36:57 +0100 Subject: [PATCH 10/28] script --- examples/smac/smac_test.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 19de7172ae7..b434650cd47 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,7 +1,7 @@ from env import SCEnv from examples.smac.policy import MaskedLogitPolicy from torchrl.envs import TransformedEnv, ObservationNorm -from torchrl.modules import ProbabilisticTDModule, OneHotCategorical +from torchrl.modules import ProbabilisticTDModule, OneHotCategorical, QValueActor from torch import nn if __name__ == "__main__": @@ -39,3 +39,15 @@ # check that an ation can be performed in the env with this env.step(td) print(td) + + # we can also have a regular Q-Value actor + print('\n\nQValue') + policy_td_module = QValueActor( + policy_wrap, spec=None, + in_keys=["observation", "available_actions"], + out_keys=["actions"]) + td = env.reset() + policy_td_module(td) + print('action: ', td.get("action")) + env.step(td) + print('next_obs: ', td.get("next_observation")) From 143eba17b41820b051407b7b87d0ecdb32be363d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:42:11 +0100 Subject: [PATCH 11/28] script --- examples/smac/smac_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index b434650cd47..70ee20d39b9 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -45,7 +45,8 @@ policy_td_module = QValueActor( policy_wrap, spec=None, in_keys=["observation", "available_actions"], - out_keys=["actions"]) + # out_keys=["actions"] + ) td = env.reset() policy_td_module(td) print('action: ', td.get("action")) From 88ef3c31fb23aab4741929f46f482dbba0e10255 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 13:46:02 +0100 Subject: [PATCH 12/28] script --- torchrl/modules/td_module/actors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/modules/td_module/actors.py b/torchrl/modules/td_module/actors.py index 455fe85d531..f5d8fe98def 100644 --- a/torchrl/modules/td_module/actors.py +++ b/torchrl/modules/td_module/actors.py @@ -239,6 +239,8 @@ def __init__( def __call__( self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(values, tuple): + values = values[0] action = self.fun_dict[self.action_space](values) chosen_action_value = (action * values).sum(-1, True) return action, values, chosen_action_value @@ -331,6 +333,8 @@ def __init__( def __call__( self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(values, tuple): + values = values[0] action = self.fun_dict[self.action_space](values, self.support) return action, values From 80c9aae907da9a9205cf29097668577680e88902 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:03:21 +0100 Subject: [PATCH 13/28] collector --- examples/smac/smac_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 70ee20d39b9..aed41151d14 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,5 +1,6 @@ from env import SCEnv from examples.smac.policy import MaskedLogitPolicy +from torchrl.agents.helpers import sync_async_collector from torchrl.envs import TransformedEnv, ObservationNorm from torchrl.modules import ProbabilisticTDModule, OneHotCategorical, QValueActor from torch import nn @@ -52,3 +53,19 @@ print('action: ', td.get("action")) env.step(td) print('next_obs: ', td.get("next_observation")) + + # now let's collect data, see MultiaSyncDataCollector for info + print('\n\nCollector') + collector = sync_async_collector( + env_fns=lambda: SCEnv("8m"), + num_collectors=4, # 4 main processes + num_env_per_collector=8, # each of the 4 collectors has 8 processes + policy=policy_td_module, + devices=["cuda:0"]*4, # each collector will execute the policy on cuda + total_frames=1000, # we'd like to have a total of 1000 frames + max_frames_per_traj=10, # we'll reset after 10 steps + frames_per_batch=64, # each batch should have 64 frames + init_random_frames=0, # we won't execute random actions + ) + for td in collector: + print(td) From 0b4553502d69958075b2db635918cd79d10b0890 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:04:26 +0100 Subject: [PATCH 14/28] collector --- examples/smac/smac_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index aed41151d14..99a5996347a 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -58,6 +58,7 @@ print('\n\nCollector') collector = sync_async_collector( env_fns=lambda: SCEnv("8m"), + env_kwargs=None, num_collectors=4, # 4 main processes num_env_per_collector=8, # each of the 4 collectors has 8 processes policy=policy_td_module, From 41542bf853a58dc034be29f70abddc9661dc218b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:07:49 +0100 Subject: [PATCH 15/28] collector --- examples/smac/smac_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 99a5996347a..cce44ec9e74 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,6 +1,7 @@ from env import SCEnv from examples.smac.policy import MaskedLogitPolicy from torchrl.agents.helpers import sync_async_collector +from torchrl.data import TensorDictPrioritizedReplayBuffer from torchrl.envs import TransformedEnv, ObservationNorm from torchrl.modules import ProbabilisticTDModule, OneHotCategorical, QValueActor from torch import nn @@ -68,5 +69,9 @@ frames_per_batch=64, # each batch should have 64 frames init_random_frames=0, # we won't execute random actions ) + print('replay buffer') + rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1) for td in collector: - print(td) + rb.extend(td.view(-1)) + + print('rb sample: ', rb.sample(2)) From 9294192169cd8ce02535ab33e82f0298aab0f0dc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:12:24 +0100 Subject: [PATCH 16/28] collector --- examples/smac/smac_test.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index cce44ec9e74..4c017f86ef9 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -72,6 +72,11 @@ print('replay buffer') rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1) for td in collector: - rb.extend(td.view(-1)) + print(f'collected tensordict has shape [Batch x Time]={td.shape}') + # rb.extend(td.view(-1)) # we split each action + rb.extend(td) # we split each trajectory + + collector.update_policy_weights_() # if you have updated the local + # policy (on cpu) you may want to sync the collectors' policies to it + print('rb sample: ', rb.sample(2)) - print('rb sample: ', rb.sample(2)) From 5f40cefa1343dfe46b5efa777381bd5f3af614a2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:31:32 +0100 Subject: [PATCH 17/28] collector --- examples/smac/env.py | 2 +- examples/smac/smac_test.py | 2 +- torchrl/data/tensordict/tensordict.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/smac/env.py b/examples/smac/env.py index 142d2c21643..1c8e514e598 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -120,7 +120,7 @@ def _output_transform(self, step_result): def _reset( self, tensor_dict: Optional[_TensorDict] = None, **kwargs ) -> _TensorDict: - obs = self._env.get_obs() + obs = np.ndarray(self._env.get_obs()) tensor_dict_out = TensorDict( source=self._read_obs(obs), batch_size=self.batch_size diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 4c017f86ef9..a6a5f267f04 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -74,7 +74,7 @@ for td in collector: print(f'collected tensordict has shape [Batch x Time]={td.shape}') # rb.extend(td.view(-1)) # we split each action - rb.extend(td) # we split each trajectory + rb.extend(td.unbind(0)) # we split each trajectory collector.update_policy_weights_() # if you have updated the local # policy (on cpu) you may want to sync the collectors' policies to it diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py index 1bc91520689..a8ccc765c59 100644 --- a/torchrl/data/tensordict/tensordict.py +++ b/torchrl/data/tensordict/tensordict.py @@ -2369,8 +2369,8 @@ def set( if self.batch_size != tensor.shape[: self.batch_dims]: raise RuntimeError( "Setting tensor to tensordict failed because the shapes " - "mismatch: got tensor.shape = {tensor.shape} and " - "tensordict.batch_size={self.batch_size}" + f"mismatch: got tensor.shape = {tensor.shape} and " + f"tensordict.batch_size={self.batch_size}" ) proc_tensor = self._process_tensor( tensor, check_device=False, check_tensor_shape=False @@ -2384,8 +2384,8 @@ def set_(self, key: str, tensor: COMPATIBLE_TYPES) -> _TensorDict: if self.batch_size != tensor.shape[: self.batch_dims]: raise RuntimeError( "Setting tensor to tensordict failed because the shapes " - "mismatch: got tensor.shape = {tensor.shape} and " - "tensordict.batch_size={self.batch_size}" + f"mismatch: got tensor.shape = {tensor.shape} and " + f"tensordict.batch_size={self.batch_size}" ) if key not in self.valid_keys: raise KeyError( From 3502bfd044b5cab81b9b1dc7c1690c73122c2447 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:34:57 +0100 Subject: [PATCH 18/28] bf --- examples/smac/env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/smac/env.py b/examples/smac/env.py index 1c8e514e598..6833dd29756 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -120,10 +120,10 @@ def _output_transform(self, step_result): def _reset( self, tensor_dict: Optional[_TensorDict] = None, **kwargs ) -> _TensorDict: - obs = np.ndarray(self._env.get_obs()) + obs = self._env.get_obs() tensor_dict_out = TensorDict( - source=self._read_obs(obs), batch_size=self.batch_size + source=self._read_obs(np.array(obs)), batch_size=self.batch_size ) self._is_done = torch.zeros(1, dtype=torch.bool) tensor_dict_out.set("done", self._is_done) @@ -157,7 +157,7 @@ def _step(self, tensor_dict: _TensorDict) -> _TensorDict: if done: break - obs_dict = self._read_obs(obs) + obs_dict = self._read_obs(np.array(obs)) if reward is None: reward = np.nan From 915b39a7c2319f08cbae7cbcc2e52323e36883c4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:43:03 +0100 Subject: [PATCH 19/28] bf --- torchrl/data/replay_buffers/replay_buffers.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index dc82678b1e3..f65265ea127 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -593,19 +593,22 @@ def collate_fn(x): self.priority_key = priority_key def _get_priority(self, tensor_dict: _TensorDict) -> torch.Tensor: - if tensor_dict.batch_dims: - raise RuntimeError( - "expected void batch_size for input tensor_dict in " - "rb._get_priority()" - ) try: - priority = tensor_dict.get(self.priority_key).item() - except ValueError: - raise ValueError( - f"Found a priority key of size" - f" {tensor_dict.get(self.priority_key).shape} but expected " - f"scalar value" - ) + if tensor_dict.batch_dims: + priority = tensor_dict.get(self.priority_key).mean().item() + # raise RuntimeError( + # "expected void batch_size for input tensor_dict in " + # "rb._get_priority()" + # ) + else: + try: + priority = tensor_dict.get(self.priority_key).item() + except ValueError: + raise ValueError( + f"Found a priority key of size" + f" {tensor_dict.get(self.priority_key).shape} but expected " + f"scalar value" + ) except KeyError: priority = self._default_priority return priority From c968ef5cb5c172595d63a369daadce5b6b61d081 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 30 Mar 2022 14:50:02 +0100 Subject: [PATCH 20/28] bf --- examples/smac/smac_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index a6a5f267f04..0ff675492d5 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -73,8 +73,8 @@ rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1) for td in collector: print(f'collected tensordict has shape [Batch x Time]={td.shape}') - # rb.extend(td.view(-1)) # we split each action - rb.extend(td.unbind(0)) # we split each trajectory + rb.extend(td.view(-1)) # we split each action + # rb.extend(td.unbind(0)) # we split each trajectory -- WIP collector.update_policy_weights_() # if you have updated the local # policy (on cpu) you may want to sync the collectors' policies to it From 782ba2a32773dc861fd169b8b2ecf7ef210fa2c8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 2 Apr 2022 19:09:42 +0100 Subject: [PATCH 21/28] Update environment.yml --- .circleci/unittest/linux/scripts/environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index b64ed32f221..0af0ab2ca28 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -3,7 +3,6 @@ channels: - defaults dependencies: - pip - - ninja - cmake >= 3.18 - pip: - hypothesis From 8dcf10b043614cf12ca6662d3bc76547ea6b8ce3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 13 Apr 2022 16:12:34 +0100 Subject: [PATCH 22/28] updating & lint --- examples/smac/env.py | 51 +++++++++++++++++++++++++------------- examples/smac/policy.py | 6 ++--- examples/smac/smac_test.py | 30 +++++++++++----------- 3 files changed, 52 insertions(+), 35 deletions(-) diff --git a/examples/smac/env.py b/examples/smac/env.py index 6833dd29756..2f34fee58a6 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -4,11 +4,17 @@ import numpy as np import torch from smac.env import StarCraft2Env - -from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, \ - UnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec -from torchrl.data.tensor_specs import _default_dtype_and_device, DiscreteBox, \ - DEVICE_TYPING +from torchrl.data import ( + TensorDict, + NdUnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, + OneHotDiscreteTensorSpec, +) +from torchrl.data.tensor_specs import ( + _default_dtype_and_device, + DiscreteBox, + DEVICE_TYPING, +) from torchrl.data.tensordict.tensordict import _TensorDict from torchrl.envs.common import GymLikeEnv @@ -29,9 +35,15 @@ def __init__( n, ) self.d = d - shape = torch.Size((d, space.n,)) - super(OneHotDiscreteTensorSpec, self).__init__(shape, space, device, - dtype, "discrete") + shape = torch.Size( + ( + d, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + shape, space, device, dtype, "discrete" + ) def rand(self, shape=torch.Size([])) -> torch.Tensor: return torch.nn.functional.gumbel_softmax( @@ -59,9 +71,15 @@ def __init__( n, ) self.d = d - shape = torch.Size((d, space.n,)) - super(OneHotDiscreteTensorSpec, self).__init__(shape, space, device, - dtype, "discrete") + shape = torch.Size( + ( + d, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + shape, space, device, dtype, "discrete" + ) def to(self, dest): out = super().to(dest) @@ -88,22 +106,21 @@ class SCEnv(GymLikeEnv): @property def observation_spec(self): info = self._env.get_env_info() - dim = (info['n_agents'], info['obs_shape']) + dim = (info["n_agents"], info["obs_shape"]) return NdUnboundedContinuousTensorSpec(dim) @property def action_spec(self): - info = self._env.get_env_info() + # info = self._env.get_env_info() return CustomNdOneHotDiscreteTensorSpec( - torch.tensor(self._env.get_avail_actions())) + torch.tensor(self._env.get_avail_actions()) + ) @property def reward_spec(self): return UnboundedContinuousTensorSpec() - def _build_env( - self, map_name: str, taskname=None, **kwargs - ) -> None: + def _build_env(self, map_name: str, taskname=None, **kwargs) -> None: if taskname: raise RuntimeError diff --git a/examples/smac/policy.py b/examples/smac/policy.py index 8eebb5fa697..8543c5f244d 100644 --- a/examples/smac/policy.py +++ b/examples/smac/policy.py @@ -1,6 +1,7 @@ import torch from torch import nn, Tensor + class MaskedLogitPolicy(nn.Module): def __init__(self, policy_module): super().__init__() @@ -13,7 +14,6 @@ def forward(self, *inputs): outputs = (outputs,) # first output is logits outputs[0].masked_fill_( - ~mask.to(torch.bool).expand_as(outputs[0]), - -float("inf")) + ~mask.to(torch.bool).expand_as(outputs[0]), -float("inf") + ) return tuple(outputs) - diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index 0ff675492d5..b3f00ee0f1b 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,10 +1,10 @@ from env import SCEnv from examples.smac.policy import MaskedLogitPolicy +from torch import nn from torchrl.agents.helpers import sync_async_collector from torchrl.data import TensorDictPrioritizedReplayBuffer from torchrl.envs import TransformedEnv, ObservationNorm from torchrl.modules import ProbabilisticTDModule, OneHotCategorical, QValueActor -from torch import nn if __name__ == "__main__": # create an env @@ -33,50 +33,50 @@ # Test the policy policy_td_module(td) print(td) - print('param: ', td.get("action_dist_param_0")) - print('action: ', td.get("action")) - print('mask: ', td.get("available_actions")) - print('mask from env: ', env.env._env.get_avail_actions()) + print("param: ", td.get("action_dist_param_0")) + print("action: ", td.get("action")) + print("mask: ", td.get("available_actions")) + print("mask from env: ", env.env._env.get_avail_actions()) # check that an ation can be performed in the env with this env.step(td) print(td) # we can also have a regular Q-Value actor - print('\n\nQValue') + print("\n\nQValue") policy_td_module = QValueActor( - policy_wrap, spec=None, + policy_wrap, + spec=None, in_keys=["observation", "available_actions"], # out_keys=["actions"] ) td = env.reset() policy_td_module(td) - print('action: ', td.get("action")) + print("action: ", td.get("action")) env.step(td) - print('next_obs: ', td.get("next_observation")) + print("next_obs: ", td.get("next_observation")) # now let's collect data, see MultiaSyncDataCollector for info - print('\n\nCollector') + print("\n\nCollector") collector = sync_async_collector( env_fns=lambda: SCEnv("8m"), env_kwargs=None, num_collectors=4, # 4 main processes num_env_per_collector=8, # each of the 4 collectors has 8 processes policy=policy_td_module, - devices=["cuda:0"]*4, # each collector will execute the policy on cuda + devices=["cuda:0"] * 4, # each collector will execute the policy on cuda total_frames=1000, # we'd like to have a total of 1000 frames max_frames_per_traj=10, # we'll reset after 10 steps frames_per_batch=64, # each batch should have 64 frames init_random_frames=0, # we won't execute random actions ) - print('replay buffer') + print("replay buffer") rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1) for td in collector: - print(f'collected tensordict has shape [Batch x Time]={td.shape}') + print(f"collected tensordict has shape [Batch x Time]={td.shape}") rb.extend(td.view(-1)) # we split each action # rb.extend(td.unbind(0)) # we split each trajectory -- WIP collector.update_policy_weights_() # if you have updated the local # policy (on cpu) you may want to sync the collectors' policies to it - print('rb sample: ', rb.sample(2)) - + print("rb sample: ", rb.sample(2)) From 5ee762cd0dcbd721b669b6671cc5ec4e222df8c0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 15:02:56 +0100 Subject: [PATCH 23/28] integrate SMAC bits in code base --- examples/smac/env.py | 80 +-------------------------- examples/smac/policy.py | 17 ------ test/test_modules.py | 57 +++++++++++++++++-- test/test_postprocs.py | 5 +- test/test_tensor_spec.py | 29 +++++++++- torchrl/data/tensor_specs.py | 95 ++++++++++++++++++++++++++++++++ torchrl/modules/models/models.py | 52 ++++++++++++++++- 7 files changed, 232 insertions(+), 103 deletions(-) diff --git a/examples/smac/env.py b/examples/smac/env.py index 2f34fee58a6..d744cfe7ec7 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -13,91 +13,13 @@ from torchrl.data.tensor_specs import ( _default_dtype_and_device, DiscreteBox, + CustomNdOneHotDiscreteTensorSpec, DEVICE_TYPING, ) from torchrl.data.tensordict.tensordict import _TensorDict from torchrl.envs.common import GymLikeEnv -@dataclass(repr=False) -class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): - def __init__( - self, - n: int, - d: int, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, - use_register: bool = False, - ): - dtype, device = _default_dtype_and_device(dtype, device) - self.use_register = use_register - space = DiscreteBox( - n, - ) - self.d = d - shape = torch.Size( - ( - d, - space.n, - ) - ) - super(OneHotDiscreteTensorSpec, self).__init__( - shape, space, device, dtype, "discrete" - ) - - def rand(self, shape=torch.Size([])) -> torch.Tensor: - return torch.nn.functional.gumbel_softmax( - torch.rand(*shape, self.d, self.space.n, device=self.device), - hard=True, - dim=-1, - ).to(torch.long) - - -@dataclass(repr=False) -class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec): - def __init__( - self, - mask: torch.Tensor, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.long, - use_register: bool = False, - ): - self.mask = mask - *_, d, n = mask.shape - - dtype, device = _default_dtype_and_device(dtype, device) - self.use_register = use_register - space = DiscreteBox( - n, - ) - self.d = d - shape = torch.Size( - ( - d, - space.n, - ) - ) - super(OneHotDiscreteTensorSpec, self).__init__( - shape, space, device, dtype, "discrete" - ) - - def to(self, dest): - out = super().to(dest) - out.mask = self.mask.to(dest) - return out - - def rand(self, shape=torch.Size([])) -> torch.Tensor: - mask = self.mask.to(torch.float) - return torch.nn.functional.gumbel_softmax( - mask.log(), - hard=True, - dim=-1, - ).to(torch.long) - - def is_in(self, value): - return ((self.mask - value) >= 0).all() - - class SCEnv(GymLikeEnv): available_envs = ["8m"] # TODO: add to parent class diff --git a/examples/smac/policy.py b/examples/smac/policy.py index 8543c5f244d..28b4874e04a 100644 --- a/examples/smac/policy.py +++ b/examples/smac/policy.py @@ -1,19 +1,2 @@ import torch from torch import nn, Tensor - - -class MaskedLogitPolicy(nn.Module): - def __init__(self, policy_module): - super().__init__() - self.policy_module = policy_module - - def forward(self, *inputs): - *inputs, mask = inputs - outputs = self.policy_module(*inputs) - if isinstance(outputs, Tensor): - outputs = (outputs,) - # first output is logits - outputs[0].masked_fill_( - ~mask.to(torch.bool).expand_as(outputs[0]), -float("inf") - ) - return tuple(outputs) diff --git a/test/test_modules.py b/test/test_modules.py index 791654e1bd3..e3a9ba3cb82 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -3,14 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse from numbers import Number +import functorch import pytest import torch from _utils_internal import get_available_devices -from torch import nn +from torch import nn, distributions as D from torchrl.data import TensorDict -from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec +from torchrl.data.tensor_specs import ( + OneHotDiscreteTensorSpec, + CustomNdOneHotDiscreteTensorSpec, +) from torchrl.modules import ( QValueActor, ActorValueOperator, @@ -18,7 +23,8 @@ ValueOperator, ProbabilisticActor, ) -from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear +from torchrl.modules.distributions import OneHotCategorical +from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear, MaskedLogitPolicy @pytest.mark.parametrize("in_features", [3, 10, None]) @@ -175,5 +181,48 @@ def test_actorcritic(device): ) == len(policy_params) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("functional", [True, False]) +@pytest.mark.parametrize( + "mask", [torch.tensor([False, True, False, True, False]), "random"] +) +def test_maskedlogit(device, functional, mask): + batch = 10 + torch.manual_seed(0) + if isinstance(mask, str): + random_mask = True + mask = torch.zeros(batch, 5, dtype=torch.bool, device=device).bernoulli_() + else: + random_mask = False + mask = mask.to(device) + policy_net = nn.Linear(3, 5, bias=False).to(device) # model that returns logits + + policy_net_wrapped = MaskedLogitPolicy(policy_net) + if functional: + policy_net_wrapped, params = functorch.make_functional(policy_net_wrapped) + observation = torch.randn(batch, 3, device=device) + if functional: + logits_masked = policy_net_wrapped(params, observation, mask) + else: + logits_masked = policy_net_wrapped(observation, mask) + c = D.Categorical(logits=logits_masked) + samples = c.sample((1000,)) + samples_uniques = samples.unique() + if random_mask: + mask_expand = mask.expand(1000, batch, 5) + assert mask_expand.gather(-1, samples.unsqueeze(-1)).all() + else: + assert ((samples_uniques == 1) | (samples_uniques == 3)).all() + + # test synergy with CustomNdOneHotDiscreteTensorSpec + spec = CustomNdOneHotDiscreteTensorSpec(mask=mask) + c = OneHotCategorical(logits=logits_masked) + assert spec.is_in(c.sample((1000,))) + c = OneHotCategorical(logits=torch.randn_like(logits_masked)) + with pytest.raises(AssertionError): + assert spec.is_in(c.sample((1000,))) + + if __name__ == "__main__": - pytest.main([__file__]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 913a00418e7..98d95b9a0e1 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse + import pytest import torch from torchrl.collectors.utils import split_trajectories @@ -164,4 +166,5 @@ def test_splits(self, num_workers, traj_len): if __name__ == "__main__": - pytest.main([__file__, "--capture", "no"]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index a7ce40fe045..8eb796abe58 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse + import numpy as np import pytest import torch @@ -15,6 +17,7 @@ BoundedTensorSpec, UnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, + CustomNdOneHotDiscreteTensorSpec, ) @@ -226,6 +229,29 @@ def test_mult_onehot(shape, ns): assert (ts.encode(np_r) == r).all() +@pytest.mark.parametrize("n", range(10, 12)) +@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([10])]) +def test_custom_ndonehot(n, shape): + torch.manual_seed(0) + np.random.seed(0) + + with pytest.raises(RuntimeError): + mask = torch.zeros(*shape, n).bernoulli_() + ts = CustomNdOneHotDiscreteTensorSpec(mask) + mask = torch.zeros(*shape, n, dtype=torch.bool).bernoulli_() + ts = CustomNdOneHotDiscreteTensorSpec(mask) + + for _ in range(100): + r = ts.rand([10]) + assert r.shape == torch.Size([10, *shape, n]) + assert ts.is_in(r), r + assert ((r == 0) | (r == 1)).all() + r_numpy = r.argmax(-1).numpy() + assert (ts.encode(r_numpy) == r).all() + assert (ts.encode(ts.to_numpy(r)) == r).all() + assert (r.sum(-1) == 1).all() + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) @pytest.mark.parametrize( "shape", @@ -260,4 +286,5 @@ def test_composite(shape, dtype): if __name__ == "__main__": - pytest.main([__file__]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 76e626ee5c7..bad5a5e5dbd 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -33,6 +33,8 @@ "NdUnboundedContinuousTensorSpec", "BinaryDiscreteTensorSpec", "MultOneHotDiscreteTensorSpec", + "NdOneHotDiscreteTensorSpec", + "CustomNdOneHotDiscreteTensorSpec", "CompositeSpec", ] @@ -780,6 +782,99 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return torch.cat([super()._project(_val) for _val in vals], -1) +@dataclass(repr=False) +class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): + """An N-dimensional One hot discrete tensor spec data class""" + + def __init__( + self, + n: int, + *shape: int, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.shape = shape + total_shape = torch.Size( + ( + *shape, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + total_shape, space, device, dtype, "discrete" + ) + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.nn.functional.gumbel_softmax( + torch.rand(*shape, self.d, self.space.n, device=self.device), + hard=True, + dim=-1, + ).to(torch.long) + + +@dataclass(repr=False) +class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec): + """A masked N-dimensional One-Hot discrete tensor spec data-class + + The aim of this class is to check / project or document a discrete space + when it varies from environment to environment, or from step to step in the + same environment. + + """ + + def __init__( + self, + mask: torch.Tensor, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + if mask.dtype is not torch.bool: + raise RuntimeError( + f"Expected a mask with dtype torch.bool but got {mask.dtype}" + ) + if (mask.sum(-1) == 0).any(): + raise RuntimeError("Got an empty mask for some dimension.") + self.mask = mask + *shape, n = mask.shape + + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.shape = shape + total_shape = torch.Size( + ( + *shape, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + total_shape, space, device, dtype, "discrete" + ) + + def to(self, dest): + out = super().to(dest) + out.mask = self.mask.to(dest) + return out + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + mask = self.mask.expand(*shape, *self.mask.shape) + r = torch.rand(mask.shape, device=mask.device).masked_fill_(~mask, 0.0) + return (r == r.max(-1, keepdim=True)[0]).to(torch.long) + + def is_in(self, value): + congruent = self.mask & value.to(torch.bool) + return (congruent.sum(-1) == 1).all() + + class CompositeSpec(TensorSpec): """ A composition of TensorSpecs. diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 21ad0e49e8e..ce664007b96 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -8,7 +8,7 @@ import numpy as np import torch -from torch import nn +from torch import nn, Tensor from torch.nn import functional as F from torchrl.modules.models.utils import ( @@ -28,6 +28,7 @@ "DdpgMlpActor", "DdpgMlpQNet", "LSTMNet", + "MaskedLogitPolicy", ] @@ -894,6 +895,55 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens return value +class MaskedLogitPolicy(nn.Module): + """Wrapper for logit output masking. + + This module takes as input a logit tensor and a mask, and returns a similar + logit tensor where invalid values have been masked out. + + This is aimed to be used in environments where the space of valid actions + is dynamic, or in settings where multiple tasks with different discrete + action space are trained using the same policy model. + + Examples: + >>> policy_net = nn.Linear(3, 5, bias=False) # model that returns logits + >>> mask = torch.tensor([False, True, False, True, False]) + >>> policy_net_wrapped = MaskedLogitPolicy(policy_net) + >>> observation = torch.zeros(2, 3) + >>> logits_masked = policy_net_wrapped(observation, mask) + >>> print(logits_masked) + (tensor([[-inf, 0., -inf, 0., -inf], + [-inf, 0., -inf, 0., -inf]], grad_fn=),) + + This class also supports functional modules: + >>> import functorch + >>> policy_net = nn.Linear(3, 5, bias=False) # model that returns logits + >>> mask = torch.tensor([False, True, False, True, False]) + >>> policy_net_wrapped = MaskedLogitPolicy(policy_net) + >>> fpolicy_net_wrapped, params = functorch.make_functional(policy_net_wrapped) + >>> observation = torch.zeros(2, 3) + >>> logits_masked = fpolicy_net_wrapped(params, observation, mask) + >>> print(logits_masked) + (tensor([[-inf, 0., -inf, 0., -inf], + [-inf, 0., -inf, 0., -inf]], grad_fn=),) + """ + + def __init__(self, policy_module): + super().__init__() + self.policy_module = policy_module + + def forward(self, *inputs, **kwargs): + *inputs, mask = inputs + outputs = self.policy_module(*inputs, **kwargs) + if isinstance(outputs, Tensor): + _outputs = outputs + else: + _outputs = outputs[0] + # first output is logits + _outputs.masked_fill_(~mask.to(torch.bool).expand_as(_outputs), -float("inf")) + return outputs + + class LSTMNet(nn.Module): """ An embedder for an LSTM followed by an MLP. From 6f56ea7e30a755ca7815c2ad9808d1024e78b95e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 16:28:44 +0100 Subject: [PATCH 24/28] lint --- examples/smac/env.py | 7 +------ examples/smac/policy.py | 2 -- examples/smac/smac_test.py | 8 ++++++-- 3 files changed, 7 insertions(+), 10 deletions(-) delete mode 100644 examples/smac/policy.py diff --git a/examples/smac/env.py b/examples/smac/env.py index d744cfe7ec7..684746f2971 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional import numpy as np import torch @@ -8,13 +7,9 @@ TensorDict, NdUnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec, - OneHotDiscreteTensorSpec, ) from torchrl.data.tensor_specs import ( - _default_dtype_and_device, - DiscreteBox, CustomNdOneHotDiscreteTensorSpec, - DEVICE_TYPING, ) from torchrl.data.tensordict.tensordict import _TensorDict from torchrl.envs.common import GymLikeEnv diff --git a/examples/smac/policy.py b/examples/smac/policy.py deleted file mode 100644 index 28b4874e04a..00000000000 --- a/examples/smac/policy.py +++ /dev/null @@ -1,2 +0,0 @@ -import torch -from torch import nn, Tensor diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py index b3f00ee0f1b..adf67f6aa18 100644 --- a/examples/smac/smac_test.py +++ b/examples/smac/smac_test.py @@ -1,10 +1,14 @@ from env import SCEnv -from examples.smac.policy import MaskedLogitPolicy from torch import nn from torchrl.agents.helpers import sync_async_collector from torchrl.data import TensorDictPrioritizedReplayBuffer from torchrl.envs import TransformedEnv, ObservationNorm -from torchrl.modules import ProbabilisticTDModule, OneHotCategorical, QValueActor +from torchrl.modules import ( + ProbabilisticTDModule, + OneHotCategorical, + QValueActor, + MaskedLogitPolicy, +) if __name__ == "__main__": # create an env From 33ba39f5a448149bf9ad1d6313b60edb144e98d3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 16:31:10 +0100 Subject: [PATCH 25/28] lint --- examples/smac/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/smac/env.py b/examples/smac/env.py index 684746f2971..4ee5cc5564e 100644 --- a/examples/smac/env.py +++ b/examples/smac/env.py @@ -49,7 +49,7 @@ def _output_transform(self, step_result): reward, done, *other = step_result obs = self._env.get_obs() available_actions = self._env.get_avail_actions() - return obs, reward, done, available_actions, *other + return (obs, reward, done, available_actions, *other) def _reset( self, tensor_dict: Optional[_TensorDict] = None, **kwargs From 6ce41cb36e97916decfd713366f17f5fcd034649 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 16:51:31 +0100 Subject: [PATCH 26/28] Add extra_require to setup.py --- README.md | 7 +++++++ setup.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 47287f3e8bd..c7fad5a2d78 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,13 @@ pip install gym gym[accept-rom-license] pygame gym_retro pip install pytest ``` +Alternatively, extra dependencies can be installed using +``` +pip install ".[atari,dm_control,gym_continuous,rendering,tests,utils]" +``` +or a selection of these. + + **Troubleshooting** If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs, it means that the C++ extensions were not installed or not found. diff --git a/setup.py b/setup.py index cb43b9ca6cc..e6294a088ec 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,10 @@ import subprocess from pathlib import Path -from build_tools import setup_helpers from setuptools import setup, find_packages +from build_tools import setup_helpers + def _get_pytorch_version(): if "PYTORCH_VERSION" in os.environ: @@ -78,6 +79,14 @@ def _main(): "clean": clean, }, install_requires=[pytorch_package_dep, "numpy"], + extras_require={ + "atari": ["gym", "atari-py", "ale-py", "gym[accept-rom-license]", "pygame"], + "dm_control": ["dm_control"], + "gym_continuous": ["mujoco-py", "mujoco"], + "rendering": ["moviepy"], + "tests": ["pytest"], + "utils": ["tqdm", "configargparse"], + }, ) From 4ac14354ad267252d9fd6f6ab204c670d1f0fe00 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 17:05:37 +0100 Subject: [PATCH 27/28] extra deps --- README.md | 4 +- setup.py | 1 + torchrl/envs/libs/smac.py | 118 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 torchrl/envs/libs/smac.py diff --git a/README.md b/README.md index c7fad5a2d78..5d2936dd6dc 100644 --- a/README.md +++ b/README.md @@ -111,10 +111,12 @@ pip install pytest Alternatively, extra dependencies can be installed using ``` -pip install ".[atari,dm_control,gym_continuous,rendering,tests,utils]" +pip install ".[atari,dm_control,gym_continuous,rendering,tests,utils,smac]" ``` or a selection of these. +**N.B.**: SMAC (Starcraft Multiagent Contest) requires you to install Starcraft II. +Support on non-linux machines is limited. Please refer to the [original repo](https://github.com/oxwhirl/smac) for more information. **Troubleshooting** diff --git a/setup.py b/setup.py index e6294a088ec..4983ff36c4d 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ def _main(): "rendering": ["moviepy"], "tests": ["pytest"], "utils": ["tqdm", "configargparse"], + "smac": ["smac @ git+https://github.com/oxwhirl/smac.git"], }, ) diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py new file mode 100644 index 00000000000..7279a7714b5 --- /dev/null +++ b/torchrl/envs/libs/smac.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import numpy as np +import torch +from smac.env import StarCraft2Env + +from torchrl.data import ( + TensorDict, + NdUnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.tensor_specs import ( + CustomNdOneHotDiscreteTensorSpec, +) +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs.common import GymLikeEnv + + +class SCEnv(GymLikeEnv): + available_envs = ["8m"] + # TODO: add to parent class + supplementary_keys = ["available_actions"] + + @property + def observation_spec(self): + info = self._env.get_env_info() + dim = (info["n_agents"], info["obs_shape"]) + return NdUnboundedContinuousTensorSpec(dim) + + @property + def action_spec(self): + # info = self._env.get_env_info() + return CustomNdOneHotDiscreteTensorSpec( + torch.tensor(self._env.get_avail_actions()) + ) + + @property + def reward_spec(self): + return UnboundedContinuousTensorSpec() + + def _build_env(self, map_name: str, taskname=None, **kwargs) -> None: + if taskname: + raise RuntimeError + + env = StarCraft2Env(map_name=map_name) + self._env = env + return env + + def _output_transform(self, step_result): + reward, done, *other = step_result + obs = self._env.get_obs() + available_actions = self._env.get_avail_actions() + return (obs, reward, done, available_actions, *other) + + def _reset( + self, tensor_dict: Optional[_TensorDict] = None, **kwargs + ) -> _TensorDict: + obs = self._env.get_obs() + + tensor_dict_out = TensorDict( + source=self._read_obs(np.array(obs)), batch_size=self.batch_size + ) + self._is_done = torch.zeros(1, dtype=torch.bool) + tensor_dict_out.set("done", self._is_done) + available_actions = self._env.get_avail_actions() + tensor_dict_out.set("available_actions", available_actions) + return tensor_dict_out + + def _init_env(self, seed=None): + self._env.reset() + if seed is not None: + self.set_seed(seed) + + # TODO: check that actions match avail + def _action_transform(self, action): + action_np = self.action_spec.to_numpy(action) + return action_np + + # TODO: move to GymLike + def _step(self, tensor_dict: _TensorDict) -> _TensorDict: + action = tensor_dict.get("action") + action_np = self._action_transform(action) + + reward = 0.0 + for _ in range(self.wrapper_frame_skip): + obs, _reward, done, *info = self._output_transform( + self._env.step(action_np) + ) + if _reward is None: + _reward = 0.0 + reward += _reward + if done: + break + + obs_dict = self._read_obs(np.array(obs)) + + if reward is None: + reward = np.nan + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) + done = self._to_tensor(done, dtype=torch.bool) + self._is_done = done + self._current_tensordict = obs_dict + + tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size) + for key, value in obs_dict.items(): + tensor_dict_out.set(f"next_{key}", value) + tensor_dict_out.set("reward", reward) + tensor_dict_out.set("done", done) + for k, value in zip(self.supplementary_keys, info): + tensor_dict_out.set(k, value) + + return tensor_dict_out From fe0eb8857e916861eff77197b52c3800ef2af46d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Apr 2022 17:26:58 +0100 Subject: [PATCH 28/28] lint --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 4983ff36c4d..447cf3c62d3 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,8 @@ import subprocess from pathlib import Path -from setuptools import setup, find_packages - from build_tools import setup_helpers +from setuptools import setup, find_packages def _get_pytorch_version():