Commit
Merge remote-tracking branch 'origin/main' into minari
vmoens committed Dec 1, 2023
2 parents 1f9c0d6 + d545364 commit 7247450
Showing 8 changed files with 291 additions and 133 deletions.
38 changes: 18 additions & 20 deletions benchmarks/ecosystem/gym_env_throughput.py
@@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend):
# regular parallel env
for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
def make(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)
return GymEnv(envname, device="cpu")

# env_make = EnvCreator(make)
penv = ParallelEnv(num_workers, EnvCreator(make))
penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
with torch.inference_mode():
# warmup
penv.rollout(2)
@@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device):

for device in avail_devices:

def make(envname=envname, gym_backend=gym_backend, device=device):
def make(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)
return GymEnv(envname, device="cpu")

env_make = EnvCreator(make)
# penv = SerialEnv(num_workers, env_make)
penv = ParallelEnv(num_workers, env_make)
penv = ParallelEnv(num_workers, env_make, device=device)
collector = SyncDataCollector(
penv,
RandomPolicy(penv.action_spec),
@@ -164,14 +164,14 @@ def make_env(
for device in avail_devices:
# async collector
# + torchrl parallel env
def make_env(
envname=envname, gym_backend=gym_backend, device=device
):
def make_env(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)
return GymEnv(envname, device="cpu")

penv = ParallelEnv(
num_workers // num_collectors, EnvCreator(make_env)
num_workers // num_collectors,
EnvCreator(make_env),
device=device,
)
collector = MultiaSyncDataCollector(
[penv] * num_collectors,
@@ -206,10 +206,9 @@ def make_env(
envname=envname,
num_workers=num_workers,
gym_backend=gym_backend,
device=device,
):
with set_gym_backend(gym_backend):
penv = GymEnv(envname, num_envs=num_workers, device=device)
penv = GymEnv(envname, num_envs=num_workers, device="cpu")
return penv

penv = EnvCreator(
@@ -247,14 +247,14 @@ def make_env(
for device in avail_devices:
# sync collector
# + torchrl parallel env
def make_env(
envname=envname, gym_backend=gym_backend, device=device
):
def make_env(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
return GymEnv(envname, device=device)
return GymEnv(envname, device="cpu")

penv = ParallelEnv(
num_workers // num_collectors, EnvCreator(make_env)
num_workers // num_collectors,
EnvCreator(make_env),
device=device,
)
collector = MultiSyncDataCollector(
[penv] * num_collectors,
@@ -289,10 +289,9 @@ def make_env(
envname=envname,
num_workers=num_workers,
gym_backend=gym_backend,
device=device,
):
with set_gym_backend(gym_backend):
penv = GymEnv(envname, num_envs=num_workers, device=device)
penv = GymEnv(envname, num_envs=num_workers, device="cpu")
return penv

penv = EnvCreator(
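For context, a minimal sketch of the pattern these benchmark changes converge on: each worker builds its GymEnv on "cpu" and the target device is applied once, at the ParallelEnv level. The environment name, backend string, and worker count below are illustrative placeholders, not values from this commit.

import torch

from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend


def make(envname="CartPole-v1", gym_backend="gymnasium"):
    # Each worker builds its env on CPU; no per-worker device argument.
    with set_gym_backend(gym_backend):
        return GymEnv(envname, device="cpu")


if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # The device cast now happens in ParallelEnv, not in each sub-env.
    penv = ParallelEnv(4, EnvCreator(make), device=device)
    with torch.inference_mode():
        rollout = penv.rollout(2)  # warmup rollout, as in the benchmark
    print(rollout.device)
    penv.close()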
25 changes: 15 additions & 10 deletions examples/dreamer/dreamer_utils.py
@@ -147,6 +147,7 @@ def transformed_env_constructor(
state_dim_gsde: Optional[int] = None,
batch_dims: Optional[int] = 0,
obs_norm_state_dict: Optional[dict] = None,
ignore_device: bool = False,
) -> Union[Callable, EnvCreator]:
"""
Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
@@ -179,6 +180,7 @@ def transformed_env_constructor(
it should be set to 1 (or the number of dims of the batch).
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded
into the environment
ignore_device (bool, optional): if True, the device is ignored.
"""

def make_transformed_env(**kwargs) -> TransformedEnv:
@@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
from_pixels = cfg.from_pixels

if custom_env is None and custom_env_maker is None:
if isinstance(cfg.collector_device, str):
device = cfg.collector_device
elif isinstance(cfg.collector_device, Sequence):
device = cfg.collector_device[0]
if not ignore_device:
if isinstance(cfg.collector_device, str):
device = cfg.collector_device
elif isinstance(cfg.collector_device, Sequence):
device = cfg.collector_device[0]
else:
raise ValueError(
"collector_device must be either a string or a sequence of strings"
)
else:
raise ValueError(
"collector_device must be either a string or a sequence of strings"
)
device = None
env_kwargs = {
"env_name": env_name,
"device": device,
@@ -252,19 +257,19 @@ def parallel_env_constructor(
kwargs: keyword arguments for the `transformed_env_constructor` method.
"""
batch_transform = cfg.batch_transform
kwargs.update({"cfg": cfg, "use_env_creator": True})
if cfg.env_per_collector == 1:
kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(**kwargs)
return make_transformed_env
kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(
return_transformed_envs=not batch_transform, **kwargs
return_transformed_envs=not batch_transform, ignore_device=True, **kwargs
)
parallel_env = ParallelEnv(
num_workers=cfg.env_per_collector,
create_env_fn=make_transformed_env,
create_env_kwargs=None,
pin_memory=cfg.pin_memory,
device=cfg.collector_device,
)
if batch_transform:
kwargs.update(
42 changes: 42 additions & 0 deletions test/test_env.py
@@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0):


class TestParallel:
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="No cuda device detected."
)
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("hetero", [True, False])
@pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"])
@pytest.mark.parametrize("edevice", ["cpu", "cuda"])
@pytest.mark.parametrize("bwad", [True, False])
def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if not hetero:
env = cls(
2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice
)
else:
env1 = lambda: ContinuousActionVecMockEnv(device=edevice)
env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice))
env = cls(2, [env1, env2], device=pdevice)

r = env.rollout(2, break_when_any_done=bwad)
if pdevice is not None:
assert env.device.type == torch.device(pdevice).type
assert r.device.type == torch.device(pdevice).type
assert all(
item.device.type == torch.device(pdevice).type
for item in r.values(True, True)
)
else:
assert env.device.type == torch.device(edevice).type
assert r.device.type == torch.device(edevice).type
assert all(
item.device.type == torch.device(edevice).type
for item in r.values(True, True)
)
if parallel:
assert (
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

@pytest.mark.parametrize("num_parallel_env", [1, 10])
@pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
34 changes: 34 additions & 0 deletions test/test_transforms.py
@@ -9,6 +9,7 @@

import itertools
import pickle
import re
import sys
from copy import copy
from functools import partial
@@ -4878,6 +4879,39 @@ def test_sum_reward(self, keys, device):
def test_transform_inverse(self):
raise pytest.skip("No inverse for RewardSum")

@pytest.mark.parametrize("in_keys", [["reward"], ["reward_1", "reward_2"]])
@pytest.mark.parametrize(
"out_keys", [["episode_reward"], ["episode_reward_1", "episode_reward_2"]]
)
@pytest.mark.parametrize("reset_keys", [["_reset"], ["_reset1", "_reset2"]])
def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10):
reset_dict = {
reset_key: torch.zeros(batch, dtype=torch.bool) for reset_key in reset_keys
}
reward_sum_dict = {out_key: torch.randn(batch) for out_key in out_keys}
reset_dict.update(reward_sum_dict)
td = TensorDict(reset_dict, [])

if len(in_keys) != len(out_keys):
with pytest.raises(
ValueError,
match="RewardSum expects the same number of input and output keys",
):
RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys)
else:
t = RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys)

if len(in_keys) != len(reset_keys):
with pytest.raises(
ValueError,
match=re.escape(
f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}"
),
):
t.reset(td)
else:
t.reset(td)


class TestReward2Go(TransformBase):
@pytest.mark.parametrize("device", get_default_devices())
26 changes: 10 additions & 16 deletions torchrl/collectors/collectors.py
@@ -280,11 +280,10 @@ def _get_policy_and_device(
device = torch.device(device) if device is not None else policy_device
get_weights_fn = None
if policy_device != device:
param_and_buf = dict(policy.named_parameters())
param_and_buf.update(dict(policy.named_buffers()))
param_and_buf = TensorDict.from_module(policy, as_module=True)

def get_weights_fn(param_and_buf=param_and_buf):
return TensorDict(param_and_buf, []).apply(lambda x: x.data)
return param_and_buf.data

policy_cast = deepcopy(policy).requires_grad_(False).to(device)
# here things may break bc policy.to("cuda") gives us weights on cuda:0 (same
@@ -308,9 +307,9 @@ def update_policy_weights_(
"""
if policy_weights is not None:
self.policy_weights.apply(lambda x: x.data).update_(policy_weights)
self.policy_weights.data.update_(policy_weights)
elif self.get_weights_fn is not None:
self.policy_weights.apply(lambda x: x.data).update_(self.get_weights_fn())
self.policy_weights.data.update_(self.get_weights_fn())

def __iter__(self) -> Iterator[TensorDictBase]:
return self.iterator()
@@ -559,10 +558,7 @@ def __init__(
)

if isinstance(self.policy, nn.Module):
self.policy_weights = TensorDict(dict(self.policy.named_parameters()), [])
self.policy_weights.update(
TensorDict(dict(self.policy.named_buffers()), [])
)
self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
else:
self.policy_weights = TensorDict({}, [])

@@ -1200,9 +1196,9 @@ def device_err_msg(device_name, devices_list):
)
self._policy_dict[_device] = _policy
if isinstance(_policy, nn.Module):
param_dict = dict(_policy.named_parameters())
param_dict.update(_policy.named_buffers())
self._policy_weights_dict[_device] = TensorDict(param_dict, [])
self._policy_weights_dict[_device] = TensorDict.from_module(
_policy, as_module=True
)
else:
self._policy_weights_dict[_device] = TensorDict({}, [])

@@ -1288,11 +1284,9 @@ def frames_per_batch_worker(self):
def update_policy_weights_(self, policy_weights=None) -> None:
for _device in self._policy_dict:
if policy_weights is not None:
self._policy_weights_dict[_device].apply(lambda x: x.data).update_(
policy_weights
)
self._policy_weights_dict[_device].data.update_(policy_weights)
elif self._get_weights_fn_dict[_device] is not None:
self._policy_weights_dict[_device].update_(
self._policy_weights_dict[_device].data.update_(
self._get_weights_fn_dict[_device]()
)

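For context, a minimal sketch of the weight-handling pattern adopted in collectors.py above: the policy's parameters and buffers are captured once with TensorDict.from_module(..., as_module=True), and later updates are written in-place through the .data view. The small two-layer policy and the +0.01 perturbation are illustrative placeholders, not part of the commit.

import torch
from tensordict import TensorDict
from torch import nn

policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

# as_module=True returns a parameter-backed tensordict tied to the module,
# replacing the earlier dict(named_parameters()) + dict(named_buffers()) dance.
policy_weights = TensorDict.from_module(policy, as_module=True)

# Later, fresh weights (e.g. shipped from a trainer process) are applied
# in-place through the .data view, without touching autograd state.
new_weights = policy_weights.data.clone().apply(lambda x: x + 0.01)
policy_weights.data.update_(new_weights)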