From 3c1241d8b82074c2c3cc1127ebae0fc9eb7d3ecf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 12:21:15 -0800 Subject: [PATCH] [Feature] ConditionalPolicySwitch transform ghstack-source-id: f147e7c6b0f55da5746f79563af66ad057021d66 Pull Request resolved: https://github.com/pytorch/rl/pull/2711 --- docs/source/reference/envs.rst | 1 + examples/agents/ppo-chess.py | 72 ++++++-- test/test_transforms.py | 203 ++++++++++++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/batched_envs.py | 11 +- torchrl/envs/custom/chess.py | 4 +- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 241 ++++++++++++++++++++++++++ 8 files changed, 510 insertions(+), 24 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index c4f3f6eda9a..cc9231a5fa9 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -816,6 +816,7 @@ to be able to create this other composition: CenterCrop ClipTransform Compose + ConditionalPolicySwitch Crop DTypeCastTransform DeviceCastTransform diff --git a/examples/agents/ppo-chess.py b/examples/agents/ppo-chess.py index f9527339e2a..6c3a7886ee5 100644 --- a/examples/agents/ppo-chess.py +++ b/examples/agents/ppo-chess.py @@ -5,20 +5,24 @@ import tensordict.nn import torch import tqdm -from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \ - ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq +from tensordict.nn import ( + ProbabilisticTensorDictModule as TDProb, + ProbabilisticTensorDictSequential as TDProbSeq, + TensorDictModule as TDMod, + TensorDictSequential as TDSeq, +) from torch import nn from torch.nn.utils import clip_grad_norm_ from torch.optim import Adam from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement from torchrl.envs import ChessEnv, Tokenizer from torchrl.modules import MLP from torchrl.modules.distributions import MaskedCategorical from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE -from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement tensordict.nn.set_composite_lp_aggregate(False) @@ -39,7 +43,9 @@ embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64) # Embedding for the fen -embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64) +embedding_fen = nn.Embedding( + num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64 +) backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU) @@ -49,20 +55,30 @@ critic_head = nn.Linear(512, 1) critic_head.bias.data.fill_(0) -prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True) +prob = TDProb( + in_keys=["logits", "mask"], + out_keys=["action"], + distribution_class=MaskedCategorical, + return_log_prob=True, +) + def make_mask(idx): mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool) return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1] + actor = TDProbSeq( - TDMod( - make_mask, - in_keys=["legal_moves"], out_keys=["mask"]), + TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]), TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]), TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]), - TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in 
args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"], - out_keys=["features"]), + TDMod( + lambda *args: torch.cat( + [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1 + ), + in_keys=["embedded_legal_moves", "embedded_fen"], + out_keys=["features"], + ), TDMod(backbone, in_keys=["features"], out_keys=["hidden"]), TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]), prob, @@ -78,7 +94,9 @@ def make_mask(idx): optim = Adam(loss.parameters()) -gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True) +gae = GAE( + value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True +) # Create a data collector collector = SyncDataCollector( @@ -88,12 +106,20 @@ def make_mask(idx): total_frames=1_000_000, ) -replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement()) -replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement()) +replay_buffer0 = ReplayBuffer( + storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2), + batch_size=batch_size, + sampler=SamplerWithoutReplacement(), +) +replay_buffer1 = ReplayBuffer( + storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2), + batch_size=batch_size, + sampler=SamplerWithoutReplacement(), +) for data in tqdm.tqdm(collector): data = data.filter_non_tensor_data() - print('data', data[0::2]) + print("data", data[0::2]) for i in range(num_epochs): replay_buffer0.empty() replay_buffer1.empty() @@ -103,14 +129,24 @@ def make_mask(idx): # player 1 data1 = gae(data[1::2]) if i == 0: - print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6)) - print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6)) + print( + "win rate for 0", + data0["next", "reward"].sum() + / data["next", "done"].sum().clamp_min(1e-6), + ) + print( + "win rate for 1", + data1["next", "reward"].sum() + / data["next", "done"].sum().clamp_min(1e-6), + ) replay_buffer0.extend(data0) replay_buffer1.extend(data1) - n_iter = collector.frames_per_batch//(2 * batch_size) - for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter): + n_iter = collector.frames_per_batch // (2 * batch_size) + for (d0, d1) in tqdm.tqdm( + zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter + ): loss_vals = (loss(d0) + loss(d1)) / 2 loss_vals.sum(reduce=True).backward() gn = clip_grad_norm_(loss.parameters(), 100.0) diff --git a/test/test_transforms.py b/test/test_transforms.py index b0d8bcfe8ef..d7429db5528 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -20,6 +20,8 @@ import tensordict.tensordict import torch +from tensordict.nn import WrapModule + from tensordict import ( NonTensorData, NonTensorStack, @@ -56,6 +58,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalPolicySwitch, Crop, DeviceCastTransform, DiscreteActionProjection, @@ -13341,6 +13344,206 @@ def test_composite_reward_spec(self) -> None: assert transform.transform_reward_spec(reward_spec) == expected_reward_spec +class TestConditionalPolicySwitch(TransformBase): + def test_single_trans_env_check(self): + base_env = CountingEnv(max_steps=15) + condition = lambda td: ((td.get("step_count") % 2) == 0).all() + # Player 0 + policy_odd = lambda td: td.set("action", env.action_spec.zero()) + policy_even = lambda td: 
td.set("action", env.action_spec.one()) + transforms = Compose( + StepCounter(), + ConditionalPolicySwitch(condition=condition, policy=policy_even), + ) + env = base_env.append_transform(transforms) + env.check_env_specs() + + def _create_policy_odd(self, base_env): + return WrapModule( + lambda td, base_env=base_env: td.set( + "action", base_env.action_spec_unbatched.zero(td.shape) + ), + out_keys=["action"], + ) + + def _create_policy_even(self, base_env): + return WrapModule( + lambda td, base_env=base_env: td.set( + "action", base_env.action_spec_unbatched.one(td.shape) + ), + out_keys=["action"], + ) + + def _create_transforms(self, condition, policy_even): + return Compose( + StepCounter(), + ConditionalPolicySwitch(condition=condition, policy=policy_even), + ) + + def _make_env(self, max_count, env_cls): + torch.manual_seed(0) + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) + base_env = env_cls(max_steps=max_count) + policy_even = self._create_policy_even(base_env) + transforms = self._create_transforms(condition, policy_even) + return base_env.append_transform(transforms) + + def _test_env(self, env, policy_odd): + env.check_env_specs() + env.set_seed(0) + r = env.rollout(100, policy_odd, break_when_any_done=False) + # Check results are independent: one reset / step in one env should not impact results in another + r0, r1, r2 = r.unbind(0) + r0_split = r0.split(6) + assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])) + r1_split = r1.split(7) + assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])) + r2_split = r2.split(8) + assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])) + + def test_trans_serial_env_check(self): + torch.manual_seed(0) + base_env = SerialEnv( + 3, + [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)], + batch_locked=False, + ) + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) + policy_odd = self._create_policy_odd(base_env) + policy_even = self._create_policy_even(base_env) + transforms = self._create_transforms(condition, policy_even) + env = base_env.append_transform(transforms) + self._test_env(env, policy_odd) + + def test_trans_parallel_env_check(self): + torch.manual_seed(0) + base_env = ParallelEnv( + 3, + [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)], + batch_locked=False, + mp_start_method=mp_ctx, + ) + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) + policy_odd = self._create_policy_odd(base_env) + policy_even = self._create_policy_even(base_env) + transforms = self._create_transforms(condition, policy_even) + env = base_env.append_transform(transforms) + self._test_env(env, policy_odd) + + def test_serial_trans_env_check(self): + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) + policy_odd = self._create_policy_odd(CountingEnv()) + + def make_env(max_count): + return partial(self._make_env, max_count, CountingEnv) + + env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)]) + self._test_env(env, policy_odd) + + def test_parallel_trans_env_check(self): + condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1) + policy_odd = self._create_policy_odd(CountingEnv()) + + def make_env(max_count): + return partial(self._make_env, max_count, CountingEnv) + + env = ParallelEnv( + 3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx + ) + self._test_env(env, policy_odd) + + def test_transform_no_env(self): + policy_odd = lambda td: 
td + policy_even = lambda td: td + condition = lambda td: True + transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even) + with pytest.raises( + RuntimeError, + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", + ): + transforms(TensorDict()) + + def test_transform_compose(self): + policy_odd = lambda td: td + policy_even = lambda td: td + condition = lambda td: True + transforms = Compose( + ConditionalPolicySwitch(condition=condition, policy=policy_even), + ) + with pytest.raises( + RuntimeError, + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", + ): + transforms(TensorDict()) + + def test_transform_env(self): + base_env = CountingEnv(max_steps=15) + condition = lambda td: ((td.get("step_count") % 2) == 0).all() + # Player 0 + policy_odd = lambda td: td.set("action", env.action_spec.zero()) + policy_even = lambda td: td.set("action", env.action_spec.one()) + transforms = Compose( + StepCounter(), + ConditionalPolicySwitch(condition=condition, policy=policy_even), + ) + env = base_env.append_transform(transforms) + env.check_env_specs() + r = env.rollout(1000, policy_odd, break_when_all_done=True) + assert r.shape[0] == 15 + assert (r["action"] == 0).all() + assert ( + r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1) + ).all() + assert r["next", "done"].any() + + # Player 1 + condition = lambda td: ((td.get("step_count") % 2) == 1).all() + transforms = Compose( + StepCounter(), + ConditionalPolicySwitch(condition=condition, policy=policy_odd), + ) + env = base_env.append_transform(transforms) + r = env.rollout(1000, policy_even, break_when_all_done=True) + assert r.shape[0] == 16 + assert (r["action"] == 1).all() + assert ( + r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1) + ).all() + assert r["next", "done"].any() + + def test_transform_model(self): + policy_odd = lambda td: td + policy_even = lambda td: td + condition = lambda td: True + transforms = nn.Sequential( + ConditionalPolicySwitch(condition=condition, policy=policy_even), + ) + with pytest.raises( + RuntimeError, + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", + ): + transforms(TensorDict()) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + policy_odd = lambda td: td + policy_even = lambda td: td + condition = lambda td: True + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform( + ConditionalPolicySwitch(condition=condition, policy=policy_even) + ) + rb.extend(TensorDict(batch_size=[2])) + with pytest.raises( + RuntimeError, + match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.", + ): + rb.sample(2) + + def test_transform_inverse(self): + return + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index fed73755502..52f3bb3ac1b 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -55,6 +55,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalPolicySwitch, Crop, DeviceCastTransform, DiscreteActionProjection, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 51331a86346..50a77a8f557 100644 --- 
a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -191,6 +191,8 @@ class BatchedEnvBase(EnvBase): one of the environment has dynamic specs. .. note:: Learn more about dynamic specs and environments :ref:`here `. + batch_locked (bool, optional): if provided, will override the ``batch_locked`` attribute of the + nested environments. `batch_locked=False` may allow for partial steps. .. note:: One can pass keyword arguments to each sub-environments using the following @@ -305,6 +307,7 @@ def __init__( non_blocking: bool = False, mp_start_method: str = None, use_buffers: bool = None, + batch_locked: bool | None = None, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -344,6 +347,7 @@ def __init__( # if share_individual_td is None, we will assess later if the output can be stacked self.share_individual_td = share_individual_td + self._batch_locked = batch_locked self._share_memory = shared_memory self._memmap = memmap self.allow_step_when_done = allow_step_when_done @@ -610,8 +614,8 @@ def map_device(key, value, device_map=device_map): self._env_tensordict.named_apply( map_device, nested_keys=True, filter_empty=True ) - - self._batch_locked = meta_data.batch_locked + if self._batch_locked is None: + self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) devices = set() @@ -652,7 +656,8 @@ def map_device(key, value, device_map=device_map): self._env_tensordict = torch.stack( [meta_data.tensordict for meta_data in meta_data], 0 ) - self._batch_locked = meta_data[0].batch_locked + if self._batch_locked is None: + self._batch_locked = meta_data[0].batch_locked self.has_lazy_inputs = contains_lazy_spec(self.input_spec) def state_dict(self) -> OrderedDict: diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 45d5e765d3b..1446d105ae9 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -158,9 +158,7 @@ class ChessEnv(EnvBase, metaclass=_HashMeta): batch_size=torch.Size([352]), device=None, is_shared=False) - - - """ + """ # noqa: D301 _hash_table: Dict[int, str] = {} _PGN_RESTART = """[Event "?"] diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 7ee142fe811..5f661cdee6e 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -20,6 +20,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalPolicySwitch, Crop, DeviceCastTransform, DiscreteActionProjection, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index afab09d0fba..9b11834d24b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -85,6 +85,7 @@ ) from torchrl.envs.utils import ( _sort_keys, + _terminated_or_truncated, _update_during_reset, make_composite_from_td, step_mdp, @@ -10093,3 +10094,243 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase: ) return (self.weights * reward).sum(dim=-1) + + +class ConditionalPolicySwitch(Transform): + """A transform that conditionally switches between policies based on a specified condition. + + This transform evaluates a condition on the data returned by the environment's `step` method. + If the condition is met, it applies a specified policy to the data. Otherwise, the data is + returned unaltered. This is useful for scenarios where different policies need to be applied + based on certain criteria, such as alternating turns in a game. 
+ + Args: + policy (Callable[[TensorDictBase], TensorDictBase]): + The policy to be applied when the condition is met. This should be a callable that + takes a `TensorDictBase` and returns a `TensorDictBase`. + condition (Callable[[TensorDictBase], bool]): + A callable that takes a `TensorDictBase` and returns a boolean or a tensor indicating + whether the policy should be applied. + + .. warning:: This transform must have a parent environment. + + .. note:: Ideally, it should be the last transform in the stack. If the policy requires transformed + data (e.g., images), and this transform is applied before those transformations, the policy will + not receive the transformed data. + + Examples: + >>> import torch + >>> from tensordict.nn import TensorDictModule as Mod + >>> + >>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter + >>> # Create a CartPole environment. We'll be looking at the obs: if the first element of the obs is greater than + >>> # 0 (left position) we do a right action (action=0) using the switch policy. Otherwise, we use our main + >>> # policy which does a left action. + >>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True) + >>> + >>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"]) + >>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"]) + >>> + >>> cond = lambda td: td.get("observation")[..., 0] >= 0 + >>> + >>> env = base_env.append_transform( + ... Compose( + ... # We use two step counters to show that one counts the global steps, whereas the other + ... # only counts the steps where the main policy is executed + ... StepCounter(step_count_key="step_count_total"), + ... ConditionalPolicySwitch(condition=cond, policy=policy_switch), + ... StepCounter(step_count_key="step_count_main"), + ... ) + ... 
) + >>> + >>> env.set_seed(0) + >>> torch.manual_seed(0) + >>> + >>> r = env.rollout(100, policy=policy) + >>> print("action", r["action"]) + action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + >>> print("obs", r["observation"]) + obs tensor([[ 0.0322, -0.1540, 0.0111, 0.3190], + [ 0.0299, -0.1544, 0.0181, 0.3280], + [ 0.0276, -0.1550, 0.0255, 0.3414], + [ 0.0253, -0.1558, 0.0334, 0.3596], + [ 0.0230, -0.1569, 0.0422, 0.3828], + [ 0.0206, -0.1582, 0.0519, 0.4117], + [ 0.0181, -0.1598, 0.0629, 0.4469], + [ 0.0156, -0.1617, 0.0753, 0.4891], + [ 0.0130, -0.1639, 0.0895, 0.5394], + [ 0.0104, -0.1665, 0.1058, 0.5987], + [ 0.0076, -0.1696, 0.1246, 0.6685], + [ 0.0047, -0.1732, 0.1463, 0.7504], + [ 0.0016, -0.1774, 0.1715, 0.8459], + [-0.0020, 0.0150, 0.1884, 0.6117], + [-0.0017, 0.2071, 0.2006, 0.3838]]) + >>> print("obs'", r["next", "observation"]) + obs' tensor([[ 0.0299, -0.1544, 0.0181, 0.3280], + [ 0.0276, -0.1550, 0.0255, 0.3414], + [ 0.0253, -0.1558, 0.0334, 0.3596], + [ 0.0230, -0.1569, 0.0422, 0.3828], + [ 0.0206, -0.1582, 0.0519, 0.4117], + [ 0.0181, -0.1598, 0.0629, 0.4469], + [ 0.0156, -0.1617, 0.0753, 0.4891], + [ 0.0130, -0.1639, 0.0895, 0.5394], + [ 0.0104, -0.1665, 0.1058, 0.5987], + [ 0.0076, -0.1696, 0.1246, 0.6685], + [ 0.0047, -0.1732, 0.1463, 0.7504], + [ 0.0016, -0.1774, 0.1715, 0.8459], + [-0.0020, 0.0150, 0.1884, 0.6117], + [-0.0017, 0.2071, 0.2006, 0.3838], + [ 0.0105, 0.2015, 0.2115, 0.5110]]) + >>> print("total step count", r["step_count_total"].squeeze()) + total step count tensor([ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27]) + >>> print("total step with main policy", r["step_count_main"].squeeze()) + total step with main policy tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) + + """ + + def __init__( + self, + policy: Callable[[TensorDictBase], TensorDictBase], + condition: Callable[[TensorDictBase], bool], + ): + super().__init__([], []) + self.__dict__["policy"] = policy + self.condition = condition + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + cond = self.condition(next_tensordict) + if not isinstance(cond, (bool, torch.Tensor)): + raise RuntimeError( + "Calling the condition function should return a boolean or a tensor." + ) + elif isinstance(cond, (torch.Tensor,)): + if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)): + raise RuntimeError( + "Tensor outputs must have the shape of the tensordict, or contain a single element." + ) + else: + cond = torch.tensor(cond, device=tensordict.device) + + if cond.any(): + step = tensordict.get("_step", cond) + if step.shape != cond.shape: + step = step.view_as(cond) + cond = cond & step + + parent: TransformedEnv = self.parent + any_done, done = self._check_done(next_tensordict) + next_td_save = None + if any_done: + if next_tensordict.numel() == 1 or done.all(): + return next_tensordict + if parent.base_env.batch_locked: + raise RuntimeError( + "Cannot run partial steps in a batched locked environment. " + "Hint: Parallel and Serial envs can be unlocked through a keyword argument in " + "the constructor." + ) + done = done.view(next_tensordict.shape) + cond = cond & ~done + if not cond.all(): + if parent.base_env.batch_locked: + raise RuntimeError( + "Cannot run partial steps in a batched locked environment. " + "Hint: Parallel and Serial envs can be unlocked through a keyword argument in " + "the constructor." 
+ ) + next_td_save = next_tensordict + next_tensordict = next_tensordict[cond] + tensordict = tensordict[cond] + + # policy may be expensive or raise an exception when executed with unadequate data so + # we index the td first + td = self.policy( + parent.step_mdp(tensordict.copy().set("next", next_tensordict)) + ) + # Mark the partial steps if needed + if next_td_save is not None: + td_new = td.new_zeros(cond.shape) + # TODO: swap with masked_scatter when avail + td_new[cond] = td + td = td_new + td.set("_step", cond) + next_tensordict = parent._step(td) + if next_td_save is not None: + return torch.where(cond, next_tensordict, next_td_save) + return next_tensordict + return next_tensordict + + def _check_done(self, tensordict): + env = self.parent + if env._simple_done: + done = tensordict._get_str("done", default=None) + if done is not None: + any_done = done.any() + else: + any_done = False + else: + any_done = _terminated_or_truncated( + tensordict, + full_done_spec=env.output_spec["full_done_spec"], + key="_reset", + ) + done = tensordict.pop("_reset") + return any_done, done + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + cond = self.condition(tensordict_reset) + # TODO: move to validate + if not isinstance(cond, (bool, torch.Tensor)): + raise RuntimeError( + "Calling the condition function should return a boolean or a tensor." + ) + elif isinstance(cond, (torch.Tensor,)): + if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)): + raise RuntimeError( + "Tensor outputs must have the shape of the tensordict, or contain a single element." + ) + else: + cond = torch.tensor(cond, device=tensordict.device) + + if cond.any(): + reset = tensordict.get("_reset", cond) + if reset.shape != cond.shape: + reset = reset.view_as(cond) + cond = cond & reset + + parent: TransformedEnv = self.parent + reset_td_save = None + if not cond.all(): + if parent.base_env.batch_locked: + raise RuntimeError( + "Cannot run partial steps in a batched locked environment. " + "Hint: Parallel and Serial envs can be unlocked through a keyword argument in " + "the constructor." + ) + reset_td_save = tensordict_reset.copy() + tensordict_reset = tensordict_reset[cond] + tensordict = tensordict[cond] + + td = self.policy(tensordict_reset) + # Mark the partial steps if needed + if reset_td_save is not None: + td_new = td.new_zeros(cond.shape) + # TODO: swap with masked_scatter when avail + td_new[cond] = td + td = td_new + td.set("_step", cond) + tensordict_reset = parent._step(td).exclude(*parent.reward_keys) + if reset_td_save is not None: + return torch.where(cond, tensordict_reset, reset_td_save) + return tensordict_reset + + return tensordict_reset + + def forward(self, tensordict: TensorDictBase) -> Any: + raise RuntimeError( + "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional." + )
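
Usage sketch (illustrative only, not an additional change to the patch): the snippet below condenses the ConditionalPolicySwitch docstring example above into a minimal, self-contained script. It assumes a working Gymnasium "CartPole-v1" installation and the API introduced by this patch (ConditionalPolicySwitch used together with StepCounter inside a Compose).

    import torch
    from tensordict.nn import TensorDictModule as Mod
    from torchrl.envs import Compose, ConditionalPolicySwitch, GymEnv, StepCounter

    base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)

    # The main policy always picks action 1; the switch policy always picks action 0.
    policy = Mod(
        lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"]
    )
    policy_switch = Mod(
        lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"]
    )

    # Hand control to the switch policy whenever the first observation entry is non-negative.
    cond = lambda td: td.get("observation")[..., 0] >= 0

    env = base_env.append_transform(
        Compose(
            # Counts every environment step, including those taken by the switch policy.
            StepCounter(step_count_key="step_count_total"),
            ConditionalPolicySwitch(condition=cond, policy=policy_switch),
            # Counts only the steps driven by the main policy.
            StepCounter(step_count_key="step_count_main"),
        )
    )

    env.set_seed(0)
    torch.manual_seed(0)
    rollout = env.rollout(100, policy=policy)
    print(rollout["step_count_total"].squeeze())
    print(rollout["step_count_main"].squeeze())

As in the docstring example, "step_count_total" advances on every underlying step while "step_count_main" only advances when the main policy acted, which makes the divergence between the two counters a quick way to verify that the switch is firing.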