Commit

Update

[ghstack-poisoned]

vmoens committed Jan 26, 2025
2 parents eb1d63d + b10796e commit 9ad9fcc
Showing 7 changed files with 657 additions and 304 deletions.
33 changes: 33 additions & 0 deletions test/test_collector.py
@@ -690,6 +690,39 @@ def make_env():
 del env
 
 
+@pytest.mark.parametrize(
+    "break_when_any_done,break_when_all_done",
+    [[True, False], [False, True], [False, False]],
+)
+@pytest.mark.parametrize("n_envs", [1, 4])
+def test_collector_outplace_policy(n_envs, break_when_any_done, break_when_all_done):
+    def policy_inplace(td):
+        td.set("action", torch.ones(td.shape + (1,)))
+        return td
+
+    def policy_outplace(td):
+        return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+    if n_envs == 1:
+        env = CountingEnv(10)
+    else:
+        env = SerialEnv(
+            n_envs,
+            [functools.partial(CountingEnv, 10 + i) for i in range(n_envs)],
+        )
+    env.reset()
+    c_inplace = SyncDataCollector(
+        env, policy_inplace, frames_per_batch=10, total_frames=100
+    )
+    d_inplace = torch.cat(list(c_inplace), dim=0)
+    env.reset()
+    c_outplace = SyncDataCollector(
+        env, policy_outplace, frames_per_batch=10, total_frames=100
+    )
+    d_outplace = torch.cat(list(c_outplace), dim=0)
+    assert_allclose_td(d_inplace, d_outplace)
+
+
 # Deprecated reset_when_done
 # @pytest.mark.parametrize("num_env", [1, 2])
 # @pytest.mark.parametrize("env_name", ["vec"])
313 changes: 184 additions & 129 deletions test/test_env.py
@@ -250,109 +250,200 @@ def test_env_seed(env_name, frame_skip, seed=0):
     env.close()
 
 
-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1, 4])
-def test_rollout(env_name, frame_skip, seed=0):
-    if env_name is PONG_VERSIONED and version.parse(
-        gym_backend().__version__
-    ) < version.parse("0.19"):
-        # Then 100 steps in pong are not sufficient to detect a difference
-        pytest.skip("can't detect difference in gym rollout with this gym version.")
-
-    env_name = env_name()
-    env = GymEnv(env_name, frame_skip=frame_skip)
-
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout1 = env.rollout(max_steps=100)
-    assert rollout1.names[-1] == "time"
-
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout2 = env.rollout(max_steps=100)
-    assert rollout2.names[-1] == "time"
-
-    assert_allclose_td(rollout1, rollout2)
-
-    torch.manual_seed(seed)
-    env.set_seed(seed + 10)
-    env.reset()
-    rollout3 = env.rollout(max_steps=100)
-    with pytest.raises(AssertionError):
-        assert_allclose_td(rollout1, rollout3)
-    env.close()
-
-
-def test_rollout_set_truncated():
-    env = ContinuousActionVecMockEnv()
-    with pytest.raises(RuntimeError, match="set_truncated was set to True"):
-        env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    env.add_truncated_keys()
-    r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    assert r.shape == torch.Size([10])
-    assert r[..., -1]["next", "truncated"].all()
-    assert r[..., -1]["next", "done"].all()
-
-
-@pytest.mark.parametrize("max_steps", [1, 5])
-def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
-    # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
-    env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
-    policy = CountingEnvCountPolicy(
-        action_spec=env.action_spec, action_key=env.action_key
-    )
-
-    input_td = env.reset()
-    for _ in range(epochs):
-        rollout_td = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            auto_reset=False,
-            break_when_any_done=False,
-            tensordict=input_td,
-        )
-        assert (env.count == max_steps).all()
-        input_td = step_mdp(
-            rollout_td[..., -1],
-            keep_other=True,
-            exclude_action=False,
-            exclude_reward=True,
-            reward_keys=env.reward_keys,
-            action_keys=env.action_keys,
-            done_keys=env.done_keys,
-        )
-
-
-@pytest.mark.parametrize("device", get_default_devices())
-def test_rollout_predictability(device):
-    env = MockSerialEnv(device=device)
-    env.set_seed(100)
-    first = 100 % 17
-    policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
-    for p in policy.parameters():
-        p.data.fill_(1.0)
-    td_out = env.rollout(policy=policy, max_steps=200)
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("observation").squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "observation")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "reward")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("action").squeeze()
-    ).all()
+class TestRollout:
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1, 4])
+    def test_rollout(self, env_name, frame_skip, seed=0):
+        if env_name is PONG_VERSIONED and version.parse(
+            gym_backend().__version__
+        ) < version.parse("0.19"):
+            # Then 100 steps in pong are not sufficient to detect a difference
+            pytest.skip("can't detect difference in gym rollout with this gym version.")
+
+        env_name = env_name()
+        env = GymEnv(env_name, frame_skip=frame_skip)
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout1 = env.rollout(max_steps=100)
+        assert rollout1.names[-1] == "time"
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout2 = env.rollout(max_steps=100)
+        assert rollout2.names[-1] == "time"
+
+        assert_allclose_td(rollout1, rollout2)
+
+        torch.manual_seed(seed)
+        env.set_seed(seed + 10)
+        env.reset()
+        rollout3 = env.rollout(max_steps=100)
+        with pytest.raises(AssertionError):
+            assert_allclose_td(rollout1, rollout3)
+        env.close()
+
+    def test_rollout_set_truncated(self):
+        env = ContinuousActionVecMockEnv()
+        with pytest.raises(RuntimeError, match="set_truncated was set to True"):
+            env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        env.add_truncated_keys()
+        r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        assert r.shape == torch.Size([10])
+        assert r[..., -1]["next", "truncated"].all()
+        assert r[..., -1]["next", "done"].all()
+
+    @pytest.mark.parametrize("max_steps", [1, 5])
+    def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
+        # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
+        env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
+        policy = CountingEnvCountPolicy(
+            action_spec=env.action_spec, action_key=env.action_key
+        )
+
+        input_td = env.reset()
+        for _ in range(epochs):
+            rollout_td = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                auto_reset=False,
+                break_when_any_done=False,
+                tensordict=input_td,
+            )
+            assert (env.count == max_steps).all()
+            input_td = step_mdp(
+                rollout_td[..., -1],
+                keep_other=True,
+                exclude_action=False,
+                exclude_reward=True,
+                reward_keys=env.reward_keys,
+                action_keys=env.action_keys,
+                done_keys=env.done_keys,
+            )
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_rollout_predictability(self, device):
+        env = MockSerialEnv(device=device)
+        env.set_seed(100)
+        first = 100 % 17
+        policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
+        for p in policy.parameters():
+            p.data.fill_(1.0)
+        td_out = env.rollout(policy=policy, max_steps=200)
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("observation").squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "observation")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "reward")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("action").squeeze()
+        ).all()
+
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1])
+    @pytest.mark.parametrize("truncated_key", ["truncated", "done"])
+    @pytest.mark.parametrize("parallel", [False, True])
+    def test_rollout_reset(
+        self,
+        env_name,
+        frame_skip,
+        parallel,
+        truncated_key,
+        maybe_fork_ParallelEnv,
+        seed=0,
+    ):
+        env_name = env_name()
+        envs = []
+        for horizon in [20, 30, 40]:
+            envs.append(
+                lambda horizon=horizon: TransformedEnv(
+                    GymEnv(env_name, frame_skip=frame_skip),
+                    StepCounter(horizon, truncated_key=truncated_key),
+                )
+            )
+        if parallel:
+            env = maybe_fork_ParallelEnv(3, envs)
+        else:
+            env = SerialEnv(3, envs)
+        env.set_seed(100)
+        out = env.rollout(100, break_when_any_done=False)
+        assert out.names[-1] == "time"
+        assert out.shape == torch.Size([3, 100])
+        assert (
+            out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
+        ).all()
+        assert (
+            out[..., -1]["next", "step_count"].squeeze().cpu()
+            == torch.tensor([20, 10, 20])
+        ).all()
+        assert (
+            out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
+        ).all()
+
+    @pytest.mark.parametrize(
+        "break_when_any_done,break_when_all_done",
+        [[True, False], [False, True], [False, False]],
+    )
+    @pytest.mark.parametrize("n_envs,serial", [[1, None], [4, True], [4, False]])
+    def test_rollout_outplace_policy(
+        self, n_envs, serial, break_when_any_done, break_when_all_done
+    ):
+        def policy_inplace(td):
+            td.set("action", torch.ones(td.shape + (1,)))
+            return td

+        def policy_outplace(td):
+            return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+        if n_envs == 1:
+            env = CountingEnv(10)
+        elif serial:
+            env = SerialEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+            )
+        else:
+            env = ParallelEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+                mp_start_method=mp_ctx,
+            )
+        r_inplace = env.rollout(
+            40,
+            policy_inplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        r_outplace = env.rollout(
+            40,
+            policy_outplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        if break_when_any_done:
+            assert r_outplace.shape[-1:] == (11,)
+        elif break_when_all_done:
+            if n_envs > 1:
+                assert r_outplace.shape[-1:] == (14,)
+            else:
+                assert r_outplace.shape[-1:] == (11,)
+        else:
+            assert r_outplace.shape[-1:] == (40,)
+        assert_allclose_td(r_inplace, r_outplace)
 
 
 # Check that the "terminated" key is filled in automatically if only the "done"
@@ -411,42 +502,6 @@ def _step(
     assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
 
 
-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1])
-@pytest.mark.parametrize("truncated_key", ["truncated", "done"])
-@pytest.mark.parametrize("parallel", [False, True])
-def test_rollout_reset(
-    env_name, frame_skip, parallel, truncated_key, maybe_fork_ParallelEnv, seed=0
-):
-    env_name = env_name()
-    envs = []
-    for horizon in [20, 30, 40]:
-        envs.append(
-            lambda horizon=horizon: TransformedEnv(
-                GymEnv(env_name, frame_skip=frame_skip),
-                StepCounter(horizon, truncated_key=truncated_key),
-            )
-        )
-    if parallel:
-        env = maybe_fork_ParallelEnv(3, envs)
-    else:
-        env = SerialEnv(3, envs)
-    env.set_seed(100)
-    out = env.rollout(100, break_when_any_done=False)
-    assert out.names[-1] == "time"
-    assert out.shape == torch.Size([3, 100])
-    assert (
-        out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
-    ).all()
-    assert (
-        out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20])
-    ).all()
-    assert (
-        out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
-    ).all()
-
-
 class TestModelBasedEnvBase:
     @staticmethod
     def world_model():
