diff --git a/README.md b/README.md
index 47287f3e8bd..5d2936dd6dc 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,15 @@
 pip install gym gym[accept-rom-license] pygame gym_retro
 pip install pytest
 ```
+Alternatively, extra dependencies can be installed using
+```
+pip install ".[atari,dm_control,gym_continuous,rendering,tests,utils,smac]"
+```
+or a selection of these extras, e.g. `pip install ".[atari,utils]"`.
+
+**N.B.**: SMAC (StarCraft Multi-Agent Challenge) requires you to install StarCraft II.
+Support on non-Linux machines is limited. Please refer to the [original repo](https://github.com/oxwhirl/smac) for more information.
+
 
 **Troubleshooting** If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs,
 it means that the C++ extensions were not installed or not found.
diff --git a/examples/smac/env.py b/examples/smac/env.py
new file mode 100644
index 00000000000..4ee5cc5564e
--- /dev/null
+++ b/examples/smac/env.py
@@ -0,0 +1,111 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from smac.env import StarCraft2Env
+from torchrl.data import (
+    TensorDict,
+    NdUnboundedContinuousTensorSpec,
+    UnboundedContinuousTensorSpec,
+)
+from torchrl.data.tensor_specs import (
+    CustomNdOneHotDiscreteTensorSpec,
+)
+from torchrl.data.tensordict.tensordict import _TensorDict
+from torchrl.envs.common import GymLikeEnv
+
+
+class SCEnv(GymLikeEnv):
+    available_envs = ["8m"]
+    # TODO: add to parent class
+    supplementary_keys = ["available_actions"]
+
+    @property
+    def observation_spec(self):
+        info = self._env.get_env_info()
+        dim = (info["n_agents"], info["obs_shape"])
+        return NdUnboundedContinuousTensorSpec(dim)
+
+    @property
+    def action_spec(self):
+        # info = self._env.get_env_info()
+        return CustomNdOneHotDiscreteTensorSpec(
+            torch.tensor(self._env.get_avail_actions())
+        )
+
+    @property
+    def reward_spec(self):
+        return UnboundedContinuousTensorSpec()
+
+    def _build_env(self, map_name: str, taskname=None, **kwargs) -> None:
+        if taskname:
+            raise RuntimeError(f"SCEnv does not accept a taskname argument (got {taskname!r})")
+
+        env = StarCraft2Env(map_name=map_name)
+        self._env = env
+        return env
+
+    def _output_transform(self, step_result):
+        reward, done, *other = step_result
+        obs = self._env.get_obs()
+        available_actions = self._env.get_avail_actions()
+        return (obs, reward, done, available_actions, *other)
+
+    def _reset(
+        self, tensor_dict: Optional[_TensorDict] = None, **kwargs
+    ) -> _TensorDict:
+        obs = self._env.get_obs()
+
+        tensor_dict_out = TensorDict(
+            source=self._read_obs(np.array(obs)), batch_size=self.batch_size
+        )
+        self._is_done = torch.zeros(1, dtype=torch.bool)
+        tensor_dict_out.set("done", self._is_done)
+        available_actions = self._env.get_avail_actions()
+        tensor_dict_out.set("available_actions", available_actions)
+        return tensor_dict_out
+
+    def _init_env(self, seed=None):
+        self._env.reset()
+        if seed is not None:
+            self.set_seed(seed)
+
+    # TODO: check that actions match avail
+    def _action_transform(self, action):
+        action_np = self.action_spec.to_numpy(action)
+        return action_np
+
+    # TODO: move to GymLike
+    def _step(self, tensor_dict: _TensorDict) -> _TensorDict:
+        action = tensor_dict.get("action")
+        action_np = self._action_transform(action)
+
+        reward = 0.0
+        for _ in range(self.wrapper_frame_skip):
+            obs, _reward, done, *info = self._output_transform(
+                self._env.step(action_np)
+            )
+            if _reward is None:
+                _reward = 0.0
+            reward += _reward
+            if done:
+                break
+
+        obs_dict = self._read_obs(np.array(obs))
+
+        if reward is None:
+            reward = np.nan
+        reward = self._to_tensor(reward, dtype=self.reward_spec.dtype)
+        done = self._to_tensor(done, dtype=torch.bool)
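+        # cache the termination flag and the latest observation so that the
+        # wrapper can report the current state of the env between steps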
+        self._is_done = done
+        self._current_tensordict = obs_dict
+
+        tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size)
+        for key, value in obs_dict.items():
+            tensor_dict_out.set(f"next_{key}", value)
+        tensor_dict_out.set("reward", reward)
+        tensor_dict_out.set("done", done)
+        for k, value in zip(self.supplementary_keys, info):
+            tensor_dict_out.set(k, value)
+
+        return tensor_dict_out
diff --git a/examples/smac/smac_test.py b/examples/smac/smac_test.py
new file mode 100644
index 00000000000..adf67f6aa18
--- /dev/null
+++ b/examples/smac/smac_test.py
@@ -0,0 +1,86 @@
+from env import SCEnv
+from torch import nn
+from torchrl.agents.helpers import sync_async_collector
+from torchrl.data import TensorDictPrioritizedReplayBuffer
+from torchrl.envs import TransformedEnv, ObservationNorm
+from torchrl.modules import (
+    ProbabilisticTDModule,
+    OneHotCategorical,
+    QValueActor,
+    MaskedLogitPolicy,
+)
+
+if __name__ == "__main__":
+    # create an env
+    env = SCEnv("8m")
+
+    # reset
+    td = env.reset()
+    print("tensordict after reset: ")
+    print(td)
+
+    # apply a sequence of transforms
+    env = TransformedEnv(env, ObservationNorm(0, 1, standard_normal=True))
+
+    # Get policy
+    policy = nn.LazyLinear(env.action_spec.shape[-1])
+    policy_wrap = MaskedLogitPolicy(policy)
+    policy_td_module = ProbabilisticTDModule(
+        module=policy_wrap,
+        spec=None,
+        in_keys=["observation", "available_actions"],
+        out_keys=["action"],
+        distribution_class=OneHotCategorical,
+        save_dist_params=True,
+    )
+
+    # Test the policy
+    policy_td_module(td)
+    print(td)
+    print("param: ", td.get("action_dist_param_0"))
+    print("action: ", td.get("action"))
+    print("mask: ", td.get("available_actions"))
+    print("mask from env: ", env.env._env.get_avail_actions())
+
+    # check that an action can be performed in the env with this policy
+    env.step(td)
+    print(td)
+
+    # we can also have a regular Q-Value actor
+    print("\n\nQValue")
+    policy_td_module = QValueActor(
+        policy_wrap,
+        spec=None,
+        in_keys=["observation", "available_actions"],
+        # out_keys=["actions"]
+    )
+    td = env.reset()
+    policy_td_module(td)
+    print("action: ", td.get("action"))
+    env.step(td)
+    print("next_obs: ", td.get("next_observation"))
+
+    # now let's collect data, see MultiaSyncDataCollector for info
+    print("\n\nCollector")
+    collector = sync_async_collector(
+        env_fns=lambda: SCEnv("8m"),
+        env_kwargs=None,
+        num_collectors=4,  # 4 main processes
+        num_env_per_collector=8,  # each of the 4 collectors has 8 processes
+        policy=policy_td_module,
+        devices=["cuda:0"] * 4,  # each collector will execute the policy on cuda
+        total_frames=1000,  # we'd like to have a total of 1000 frames
+        max_frames_per_traj=10,  # we'll reset after 10 steps
+        frames_per_batch=64,  # each batch should have 64 frames
+        init_random_frames=0,  # we won't execute random actions
+    )
+
+    print("replay buffer")
+    rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1)
+    for td in collector:
+        print(f"collected tensordict has shape [Batch x Time]={td.shape}")
+        rb.extend(td.view(-1))  # we split each action
+        # rb.extend(td.unbind(0))  # we split each trajectory -- WIP
+
+        collector.update_policy_weights_()  # if you have updated the local
+        # policy (on cpu) you may want to sync the collectors' policies to it
+
+    print("rb sample: ", rb.sample(2))
diff --git a/setup.py b/setup.py
index cb43b9ca6cc..447cf3c62d3 100644
--- a/setup.py
+++ b/setup.py
@@ -78,6 +78,15 @@ def _main():
             "clean": clean,
         },
         install_requires=[pytorch_package_dep, "numpy"],
"numpy"], + extras_require={ + "atari": ["gym", "atari-py", "ale-py", "gym[accept-rom-license]", "pygame"], + "dm_control": ["dm_control"], + "gym_continuous": ["mujoco-py", "mujoco"], + "rendering": ["moviepy"], + "tests": ["pytest"], + "utils": ["tqdm", "configargparse"], + "smac": ["smac @ git+https://github.com/oxwhirl/smac.git"], + }, ) diff --git a/test/test_modules.py b/test/test_modules.py index 791654e1bd3..e3a9ba3cb82 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -3,14 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse from numbers import Number +import functorch import pytest import torch from _utils_internal import get_available_devices -from torch import nn +from torch import nn, distributions as D from torchrl.data import TensorDict -from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec +from torchrl.data.tensor_specs import ( + OneHotDiscreteTensorSpec, + CustomNdOneHotDiscreteTensorSpec, +) from torchrl.modules import ( QValueActor, ActorValueOperator, @@ -18,7 +23,8 @@ ValueOperator, ProbabilisticActor, ) -from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear +from torchrl.modules.distributions import OneHotCategorical +from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear, MaskedLogitPolicy @pytest.mark.parametrize("in_features", [3, 10, None]) @@ -175,5 +181,48 @@ def test_actorcritic(device): ) == len(policy_params) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("functional", [True, False]) +@pytest.mark.parametrize( + "mask", [torch.tensor([False, True, False, True, False]), "random"] +) +def test_maskedlogit(device, functional, mask): + batch = 10 + torch.manual_seed(0) + if isinstance(mask, str): + random_mask = True + mask = torch.zeros(batch, 5, dtype=torch.bool, device=device).bernoulli_() + else: + random_mask = False + mask = mask.to(device) + policy_net = nn.Linear(3, 5, bias=False).to(device) # model that returns logits + + policy_net_wrapped = MaskedLogitPolicy(policy_net) + if functional: + policy_net_wrapped, params = functorch.make_functional(policy_net_wrapped) + observation = torch.randn(batch, 3, device=device) + if functional: + logits_masked = policy_net_wrapped(params, observation, mask) + else: + logits_masked = policy_net_wrapped(observation, mask) + c = D.Categorical(logits=logits_masked) + samples = c.sample((1000,)) + samples_uniques = samples.unique() + if random_mask: + mask_expand = mask.expand(1000, batch, 5) + assert mask_expand.gather(-1, samples.unsqueeze(-1)).all() + else: + assert ((samples_uniques == 1) | (samples_uniques == 3)).all() + + # test synergy with CustomNdOneHotDiscreteTensorSpec + spec = CustomNdOneHotDiscreteTensorSpec(mask=mask) + c = OneHotCategorical(logits=logits_masked) + assert spec.is_in(c.sample((1000,))) + c = OneHotCategorical(logits=torch.randn_like(logits_masked)) + with pytest.raises(AssertionError): + assert spec.is_in(c.sample((1000,))) + + if __name__ == "__main__": - pytest.main([__file__]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 913a00418e7..98d95b9a0e1 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import argparse
+
 import pytest
 import torch
 from torchrl.collectors.utils import split_trajectories
@@ -164,4 +166,5 @@ def test_splits(self, num_workers, traj_len):
 
 
 if __name__ == "__main__":
-    pytest.main([__file__, "--capture", "no"])
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py
index a7ce40fe045..8eb796abe58 100644
--- a/test/test_tensor_spec.py
+++ b/test/test_tensor_spec.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
+
 import numpy as np
 import pytest
 import torch
@@ -15,6 +17,7 @@
     BoundedTensorSpec,
     UnboundedContinuousTensorSpec,
     OneHotDiscreteTensorSpec,
+    CustomNdOneHotDiscreteTensorSpec,
 )
 
 
@@ -226,6 +229,29 @@ def test_mult_onehot(shape, ns):
     assert (ts.encode(np_r) == r).all()
 
 
+@pytest.mark.parametrize("n", range(10, 12))
+@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([10])])
+def test_custom_ndonehot(n, shape):
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    with pytest.raises(RuntimeError):
+        mask = torch.zeros(*shape, n).bernoulli_()
+        ts = CustomNdOneHotDiscreteTensorSpec(mask)
+    mask = torch.zeros(*shape, n, dtype=torch.bool).bernoulli_()
+    ts = CustomNdOneHotDiscreteTensorSpec(mask)
+
+    for _ in range(100):
+        r = ts.rand([10])
+        assert r.shape == torch.Size([10, *shape, n])
+        assert ts.is_in(r), r
+        assert ((r == 0) | (r == 1)).all()
+        r_numpy = r.argmax(-1).numpy()
+        assert (ts.encode(r_numpy) == r).all()
+        assert (ts.encode(ts.to_numpy(r)) == r).all()
+        assert (r.sum(-1) == 1).all()
+
+
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
 @pytest.mark.parametrize(
     "shape",
@@ -260,4 +286,5 @@ def test_composite(shape, dtype):
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 76e626ee5c7..bad5a5e5dbd 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -33,6 +33,8 @@
     "NdUnboundedContinuousTensorSpec",
     "BinaryDiscreteTensorSpec",
     "MultOneHotDiscreteTensorSpec",
+    "NdOneHotDiscreteTensorSpec",
+    "CustomNdOneHotDiscreteTensorSpec",
     "CompositeSpec",
 ]
 
@@ -780,6 +782,99 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.cat([super()._project(_val) for _val in vals], -1)
 
 
+@dataclass(repr=False)
+class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
+    """An N-dimensional one-hot discrete tensor spec data class."""
+
+    def __init__(
+        self,
+        n: int,
+        *shape: int,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: Optional[Union[str, torch.dtype]] = torch.long,
+        use_register: bool = False,
+    ):
+        dtype, device = _default_dtype_and_device(dtype, device)
+        self.use_register = use_register
+        space = DiscreteBox(
+            n,
+        )
+        self.shape = shape
+        total_shape = torch.Size(
+            (
+                *shape,
+                space.n,
+            )
+        )
+        super(OneHotDiscreteTensorSpec, self).__init__(
+            total_shape, space, device, dtype, "discrete"
+        )
+
+    def rand(self, shape=torch.Size([])) -> torch.Tensor:
+        return torch.nn.functional.gumbel_softmax(
+            torch.rand(*shape, *self.shape, self.space.n, device=self.device),
+            hard=True,
+            dim=-1,
+        ).to(torch.long)
+
+
+@dataclass(repr=False)
+class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec):
+    """A masked N-dimensional one-hot discrete tensor spec data class.
+
+    The aim of this class is to check, project or document a discrete action
+    space whose set of valid actions varies from environment to environment,
+    or from step to step within the same environment.
+
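+    Example (illustrative only; the mask values below are arbitrary):
+        >>> mask = torch.tensor([False, True, True])
+        >>> spec = CustomNdOneHotDiscreteTensorSpec(mask)
+        >>> sample = spec.rand()  # one-hot sample drawn from the allowed entries only
+        >>> assert spec.is_in(sample)
+        >>> assert not spec.is_in(torch.tensor([1, 0, 0]))  # masked-out action
+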
+    """
+
+    def __init__(
+        self,
+        mask: torch.Tensor,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: Optional[Union[str, torch.dtype]] = torch.long,
+        use_register: bool = False,
+    ):
+        if mask.dtype is not torch.bool:
+            raise RuntimeError(
+                f"Expected a mask with dtype torch.bool but got {mask.dtype}"
+            )
+        if (mask.sum(-1) == 0).any():
+            raise RuntimeError("Got an empty mask for some dimension.")
+        self.mask = mask
+        *shape, n = mask.shape
+
+        dtype, device = _default_dtype_and_device(dtype, device)
+        self.use_register = use_register
+        space = DiscreteBox(
+            n,
+        )
+        self.shape = shape
+        total_shape = torch.Size(
+            (
+                *shape,
+                space.n,
+            )
+        )
+        super(OneHotDiscreteTensorSpec, self).__init__(
+            total_shape, space, device, dtype, "discrete"
+        )
+
+    def to(self, dest):
+        out = super().to(dest)
+        out.mask = self.mask.to(dest)
+        return out
+
+    def rand(self, shape=torch.Size([])) -> torch.Tensor:
+        mask = self.mask.expand(*shape, *self.mask.shape)
+        r = torch.rand(mask.shape, device=mask.device).masked_fill_(~mask, 0.0)
+        return (r == r.max(-1, keepdim=True)[0]).to(torch.long)
+
+    def is_in(self, value):
+        # a value is valid if, along the last dim, exactly one selected entry is allowed
+        congruent = self.mask & value.to(torch.bool)
+        return (congruent.sum(-1) == 1).all()
+
+
 class CompositeSpec(TensorSpec):
     """
     A composition of TensorSpecs.
diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py
new file mode 100644
index 00000000000..7279a7714b5
--- /dev/null
+++ b/torchrl/envs/libs/smac.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import numpy as np
+import torch
+from smac.env import StarCraft2Env
+
+from torchrl.data import (
+    TensorDict,
+    NdUnboundedContinuousTensorSpec,
+    UnboundedContinuousTensorSpec,
+)
+from torchrl.data.tensor_specs import (
+    CustomNdOneHotDiscreteTensorSpec,
+)
+from torchrl.data.tensordict.tensordict import _TensorDict
+from torchrl.envs.common import GymLikeEnv
+
+
+class SCEnv(GymLikeEnv):
+    available_envs = ["8m"]
+    # TODO: add to parent class
+    supplementary_keys = ["available_actions"]
+
+    @property
+    def observation_spec(self):
+        info = self._env.get_env_info()
+        dim = (info["n_agents"], info["obs_shape"])
+        return NdUnboundedContinuousTensorSpec(dim)
+
+    @property
+    def action_spec(self):
+        # info = self._env.get_env_info()
+        return CustomNdOneHotDiscreteTensorSpec(
+            torch.tensor(self._env.get_avail_actions())
+        )
+
+    @property
+    def reward_spec(self):
+        return UnboundedContinuousTensorSpec()
+
+    def _build_env(self, map_name: str, taskname=None, **kwargs) -> None:
+        if taskname:
+            raise RuntimeError(f"SCEnv does not accept a taskname argument (got {taskname!r})")
+
+        env = StarCraft2Env(map_name=map_name)
+        self._env = env
+        return env
+
+    def _output_transform(self, step_result):
+        reward, done, *other = step_result
+        obs = self._env.get_obs()
+        available_actions = self._env.get_avail_actions()
+        return (obs, reward, done, available_actions, *other)
+
+    def _reset(
+        self, tensor_dict: Optional[_TensorDict] = None, **kwargs
+    ) -> _TensorDict:
+        obs = self._env.get_obs()
+
+        tensor_dict_out = TensorDict(
+            source=self._read_obs(np.array(obs)), batch_size=self.batch_size
+        )
+        self._is_done = torch.zeros(1, dtype=torch.bool)
+        tensor_dict_out.set("done", self._is_done)
+        available_actions = self._env.get_avail_actions()
+        tensor_dict_out.set("available_actions", available_actions)
+        return tensor_dict_out
+
+    def _init_env(self, seed=None):
+        self._env.reset()
+        if seed is not None:
+            self.set_seed(seed)
+
+    # TODO: check that actions match avail
+    def _action_transform(self, action):
+        action_np = self.action_spec.to_numpy(action)
+        return action_np
+
+    # TODO: move to GymLike
+    def _step(self, tensor_dict: _TensorDict) -> _TensorDict:
+        action = tensor_dict.get("action")
+        action_np = self._action_transform(action)
+
+        reward = 0.0
+        for _ in range(self.wrapper_frame_skip):
+            obs, _reward, done, *info = self._output_transform(
+                self._env.step(action_np)
+            )
+            if _reward is None:
+                _reward = 0.0
+            reward += _reward
+            if done:
+                break
+
+        obs_dict = self._read_obs(np.array(obs))
+
+        if reward is None:
+            reward = np.nan
+        reward = self._to_tensor(reward, dtype=self.reward_spec.dtype)
+        done = self._to_tensor(done, dtype=torch.bool)
+        # cache the termination flag and the latest observation so that the
+        # wrapper can report the current state of the env between steps
+        self._is_done = done
+        self._current_tensordict = obs_dict
+
+        tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size)
+        for key, value in obs_dict.items():
+            tensor_dict_out.set(f"next_{key}", value)
+        tensor_dict_out.set("reward", reward)
+        tensor_dict_out.set("done", done)
+        for k, value in zip(self.supplementary_keys, info):
+            tensor_dict_out.set(k, value)
+
+        return tensor_dict_out
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index 21ad0e49e8e..ce664007b96 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import functional as F
 
 from torchrl.modules.models.utils import (
@@ -28,6 +28,7 @@
     "DdpgMlpActor",
     "DdpgMlpQNet",
     "LSTMNet",
+    "MaskedLogitPolicy",
 ]
 
 
@@ -894,6 +895,55 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
     return value
 
 
+class MaskedLogitPolicy(nn.Module):
+    """Wrapper that masks the logits returned by a policy network.
+
+    This module takes the policy inputs together with a mask, computes the
+    logits with the wrapped module, and returns them with the entries of
+    invalid actions set to ``-inf``.
+
+    It is intended for environments where the space of valid actions is
+    dynamic, or for settings where multiple tasks with different discrete
+    action spaces are trained using the same policy model.
+
+    Examples:
+        >>> policy_net = nn.Linear(3, 5, bias=False)  # model that returns logits
+        >>> mask = torch.tensor([False, True, False, True, False])
+        >>> policy_net_wrapped = MaskedLogitPolicy(policy_net)
+        >>> observation = torch.zeros(2, 3)
+        >>> logits_masked = policy_net_wrapped(observation, mask)
+        >>> print(logits_masked)
+        tensor([[-inf, 0., -inf, 0., -inf],
+                [-inf, 0., -inf, 0., -inf]], grad_fn=<MaskedFillBackward0>)
+
+    This class also supports functional modules:
+        >>> import functorch
+        >>> policy_net = nn.Linear(3, 5, bias=False)  # model that returns logits
+        >>> mask = torch.tensor([False, True, False, True, False])
+        >>> policy_net_wrapped = MaskedLogitPolicy(policy_net)
+        >>> fpolicy_net_wrapped, params = functorch.make_functional(policy_net_wrapped)
+        >>> observation = torch.zeros(2, 3)
+        >>> logits_masked = fpolicy_net_wrapped(params, observation, mask)
+        >>> print(logits_masked)
+        tensor([[-inf, 0., -inf, 0., -inf],
+                [-inf, 0., -inf, 0., -inf]], grad_fn=<MaskedFillBackward0>)
+    """
+
+    def __init__(self, policy_module):
+        super().__init__()
+        self.policy_module = policy_module
+
+    def forward(self, *inputs, **kwargs):
+        *inputs, mask = inputs
+        outputs = self.policy_module(*inputs, **kwargs)
+        if isinstance(outputs, Tensor):
+            _outputs = outputs
+        else:
+            # the first output is assumed to hold the logits
+            _outputs = outputs[0]
+        # mask the logits of invalid actions in-place
+        _outputs.masked_fill_(~mask.to(torch.bool).expand_as(_outputs), -float("inf"))
+        return outputs
+
+
 class LSTMNet(nn.Module):
     """
     An embedder for an LSTM followed by an MLP.