[NOMRG] SMAC integration #56

Closed
wants to merge 32 commits
9 changes: 9 additions & 0 deletions README.md
@@ -109,6 +109,15 @@ pip install gym gym[accept-rom-license] pygame gym_retro
pip install pytest
```

Alternatively, extra dependencies can be installed using
```
pip install ".[atari,dm_control,gym_continuous,rendering,tests,utils,smac]"
```
or a selection of these.

**N.B.**: SMAC (StarCraft Multi-Agent Challenge) requires you to install StarCraft II.
Support on non-Linux machines is limited. Please refer to the [original repo](https://github.com/oxwhirl/smac) for more information.
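As a quick sanity check (just a sketch, assuming SMAC, StarCraft II and the `8m` map are installed), you can query the environment info directly:
```
from smac.env import StarCraft2Env

# launch the wrapped env and print its basic dimensions
env = StarCraft2Env(map_name="8m")
info = env.get_env_info()
print(info["n_agents"], info["obs_shape"], info["n_actions"])
env.close()
```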

**Troubleshooting**

If a `ModuleNotFoundError: No module named 'torchrl._torchrl'` error occurs, it means that the C++ extensions were not installed or not found.
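A minimal check (sketch) is to try importing the extension module directly; if the import below fails, reinstall the package so that the C++ extensions get built:
```
# raises ModuleNotFoundError when the C++ extensions are missing
import torchrl._torchrl  # noqa: F401
```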
111 changes: 111 additions & 0 deletions examples/smac/env.py
@@ -0,0 +1,111 @@
from typing import Optional

import numpy as np
import torch
from smac.env import StarCraft2Env
from torchrl.data import (
    TensorDict,
    NdUnboundedContinuousTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.tensor_specs import (
    CustomNdOneHotDiscreteTensorSpec,
)
from torchrl.data.tensordict.tensordict import _TensorDict
from torchrl.envs.common import GymLikeEnv


class SCEnv(GymLikeEnv):
    available_envs = ["8m"]
    # TODO: add to parent class
    supplementary_keys = ["available_actions"]

    @property
    def observation_spec(self):
        info = self._env.get_env_info()
        dim = (info["n_agents"], info["obs_shape"])
        return NdUnboundedContinuousTensorSpec(dim)

    @property
    def action_spec(self):
        # info = self._env.get_env_info()
        return CustomNdOneHotDiscreteTensorSpec(
            torch.tensor(self._env.get_avail_actions())
        )

    @property
    def reward_spec(self):
        return UnboundedContinuousTensorSpec()

    def _build_env(self, map_name: str, taskname=None, **kwargs) -> None:
        if taskname:
            raise RuntimeError

        env = StarCraft2Env(map_name=map_name)
        self._env = env
        return env

    def _output_transform(self, step_result):
        reward, done, *other = step_result
        obs = self._env.get_obs()
        available_actions = self._env.get_avail_actions()
        return (obs, reward, done, available_actions, *other)

    def _reset(
        self, tensor_dict: Optional[_TensorDict] = None, **kwargs
    ) -> _TensorDict:
        obs = self._env.get_obs()

        tensor_dict_out = TensorDict(
            source=self._read_obs(np.array(obs)), batch_size=self.batch_size
        )
        self._is_done = torch.zeros(1, dtype=torch.bool)
        tensor_dict_out.set("done", self._is_done)
        available_actions = self._env.get_avail_actions()
        tensor_dict_out.set("available_actions", available_actions)
        return tensor_dict_out

    def _init_env(self, seed=None):
        self._env.reset()
        if seed is not None:
            self.set_seed(seed)

    # TODO: check that actions match avail
    def _action_transform(self, action):
        action_np = self.action_spec.to_numpy(action)
        return action_np

    # TODO: move to GymLike
    def _step(self, tensor_dict: _TensorDict) -> _TensorDict:
        action = tensor_dict.get("action")
        action_np = self._action_transform(action)

        reward = 0.0
        for _ in range(self.wrapper_frame_skip):
            obs, _reward, done, *info = self._output_transform(
                self._env.step(action_np)
            )
            if _reward is None:
                _reward = 0.0
            reward += _reward
            if done:
                break

        obs_dict = self._read_obs(np.array(obs))

        if reward is None:
            reward = np.nan
        reward = self._to_tensor(reward, dtype=self.reward_spec.dtype)
        done = self._to_tensor(done, dtype=torch.bool)
        self._is_done = done
        self._current_tensordict = obs_dict

        tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size)
        for key, value in obs_dict.items():
            tensor_dict_out.set(f"next_{key}", value)
        tensor_dict_out.set("reward", reward)
        tensor_dict_out.set("done", done)
        for k, value in zip(self.supplementary_keys, info):
            tensor_dict_out.set(k, value)

        return tensor_dict_out
86 changes: 86 additions & 0 deletions examples/smac/smac_test.py
@@ -0,0 +1,86 @@
from env import SCEnv
from torch import nn
from torchrl.agents.helpers import sync_async_collector
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.envs import TransformedEnv, ObservationNorm
from torchrl.modules import (
    ProbabilisticTDModule,
    OneHotCategorical,
    QValueActor,
    MaskedLogitPolicy,
)

if __name__ == "__main__":
    # create an env
    env = SCEnv("8m")

    # reset
    td = env.reset()
    print("tensordict after reset: ")
    print(td)

    # apply a sequence of transforms
    env = TransformedEnv(env, ObservationNorm(0, 1, standard_normal=True))

    # Get policy
    policy = nn.LazyLinear(env.action_spec.shape[-1])
    policy_wrap = MaskedLogitPolicy(policy)
    policy_td_module = ProbabilisticTDModule(
        module=policy_wrap,
        spec=None,
        in_keys=["observation", "available_actions"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        save_dist_params=True,
    )

    # Test the policy
    policy_td_module(td)
    print(td)
    print("param: ", td.get("action_dist_param_0"))
    print("action: ", td.get("action"))
    print("mask: ", td.get("available_actions"))
    print("mask from env: ", env.env._env.get_avail_actions())

    # check that an action can be performed in the env with this
    env.step(td)
    print(td)

    # we can also have a regular Q-Value actor
    print("\n\nQValue")
    policy_td_module = QValueActor(
        policy_wrap,
        spec=None,
        in_keys=["observation", "available_actions"],
        # out_keys=["actions"]
    )
    td = env.reset()
    policy_td_module(td)
    print("action: ", td.get("action"))
    env.step(td)
    print("next_obs: ", td.get("next_observation"))

    # now let's collect data, see MultiaSyncDataCollector for info
    print("\n\nCollector")
    collector = sync_async_collector(
        env_fns=lambda: SCEnv("8m"),
        env_kwargs=None,
        num_collectors=4,  # 4 main processes
        num_env_per_collector=8,  # each of the 4 collectors has 8 processes
        policy=policy_td_module,
        devices=["cuda:0"] * 4,  # each collector will execute the policy on cuda
        total_frames=1000,  # we'd like to have a total of 1000 frames
        max_frames_per_traj=10,  # we'll reset after 10 steps
        frames_per_batch=64,  # each batch should have 64 frames
        init_random_frames=0,  # we won't execute random actions
    )
    print("replay buffer")
    rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=1.1)
    for td in collector:
        print(f"collected tensordict has shape [Batch x Time]={td.shape}")
        rb.extend(td.view(-1))  # we split each action
        # rb.extend(td.unbind(0))  # we split each trajectory -- WIP

        collector.update_policy_weights_()  # if the local (cpu) policy has been
        # updated, you may want to sync the collectors' policies with it
    print("rb sample: ", rb.sample(2))
9 changes: 9 additions & 0 deletions setup.py
@@ -78,6 +78,15 @@ def _main():
"clean": clean,
},
install_requires=[pytorch_package_dep, "numpy"],
extras_require={
"atari": ["gym", "atari-py", "ale-py", "gym[accept-rom-license]", "pygame"],
"dm_control": ["dm_control"],
"gym_continuous": ["mujoco-py", "mujoco"],
"rendering": ["moviepy"],
"tests": ["pytest"],
"utils": ["tqdm", "configargparse"],
"smac": ["smac @ git+https://github.com/oxwhirl/smac.git"],
},
)


57 changes: 53 additions & 4 deletions test/test_modules.py
@@ -3,22 +3,28 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from numbers import Number

import functorch
import pytest
import torch
from _utils_internal import get_available_devices
from torch import nn
from torch import nn, distributions as D
from torchrl.data import TensorDict
from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
from torchrl.data.tensor_specs import (
    OneHotDiscreteTensorSpec,
    CustomNdOneHotDiscreteTensorSpec,
)
from torchrl.modules import (
    QValueActor,
    ActorValueOperator,
    TDModule,
    ValueOperator,
    ProbabilisticActor,
)
from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear
from torchrl.modules.distributions import OneHotCategorical
from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear, MaskedLogitPolicy


@pytest.mark.parametrize("in_features", [3, 10, None])
@@ -175,5 +181,48 @@ def test_actorcritic(device):
    ) == len(policy_params)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("functional", [True, False])
@pytest.mark.parametrize(
"mask", [torch.tensor([False, True, False, True, False]), "random"]
)
def test_maskedlogit(device, functional, mask):
batch = 10
torch.manual_seed(0)
if isinstance(mask, str):
random_mask = True
mask = torch.zeros(batch, 5, dtype=torch.bool, device=device).bernoulli_()
else:
random_mask = False
mask = mask.to(device)
policy_net = nn.Linear(3, 5, bias=False).to(device) # model that returns logits

policy_net_wrapped = MaskedLogitPolicy(policy_net)
if functional:
policy_net_wrapped, params = functorch.make_functional(policy_net_wrapped)
observation = torch.randn(batch, 3, device=device)
if functional:
logits_masked = policy_net_wrapped(params, observation, mask)
else:
logits_masked = policy_net_wrapped(observation, mask)
c = D.Categorical(logits=logits_masked)
samples = c.sample((1000,))
samples_uniques = samples.unique()
if random_mask:
mask_expand = mask.expand(1000, batch, 5)
assert mask_expand.gather(-1, samples.unsqueeze(-1)).all()
else:
assert ((samples_uniques == 1) | (samples_uniques == 3)).all()

# test synergy with CustomNdOneHotDiscreteTensorSpec
spec = CustomNdOneHotDiscreteTensorSpec(mask=mask)
c = OneHotCategorical(logits=logits_masked)
assert spec.is_in(c.sample((1000,)))
c = OneHotCategorical(logits=torch.randn_like(logits_masked))
with pytest.raises(AssertionError):
assert spec.is_in(c.sample((1000,)))


if __name__ == "__main__":
pytest.main([__file__])
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
5 changes: 4 additions & 1 deletion test/test_postprocs.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import pytest
import torch
from torchrl.collectors.utils import split_trajectories
@@ -164,4 +166,5 @@ def test_splits(self, num_workers, traj_len):


if __name__ == "__main__":
pytest.main([__file__, "--capture", "no"])
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
29 changes: 28 additions & 1 deletion test/test_tensor_spec.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import numpy as np
import pytest
import torch
@@ -15,6 +17,7 @@
    BoundedTensorSpec,
    UnboundedContinuousTensorSpec,
    OneHotDiscreteTensorSpec,
    CustomNdOneHotDiscreteTensorSpec,
)


@@ -226,6 +229,29 @@ def test_mult_onehot(shape, ns):
    assert (ts.encode(np_r) == r).all()


@pytest.mark.parametrize("n", range(10, 12))
@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([10])])
def test_custom_ndonehot(n, shape):
torch.manual_seed(0)
np.random.seed(0)

with pytest.raises(RuntimeError):
mask = torch.zeros(*shape, n).bernoulli_()
ts = CustomNdOneHotDiscreteTensorSpec(mask)
mask = torch.zeros(*shape, n, dtype=torch.bool).bernoulli_()
ts = CustomNdOneHotDiscreteTensorSpec(mask)

for _ in range(100):
r = ts.rand([10])
assert r.shape == torch.Size([10, *shape, n])
assert ts.is_in(r), r
assert ((r == 0) | (r == 1)).all()
r_numpy = r.argmax(-1).numpy()
assert (ts.encode(r_numpy) == r).all()
assert (ts.encode(ts.to_numpy(r)) == r).all()
assert (r.sum(-1) == 1).all()


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@pytest.mark.parametrize(
"shape",
@@ -260,4 +286,5 @@ def test_composite(shape, dtype):


if __name__ == "__main__":
    pytest.main([__file__])
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)