From 435584e1aa78819d1aa6a272f83d2c825d640c79 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 7 Sep 2023 16:11:20 +0100 Subject: [PATCH 1/5] [Minor] More efficient SAC v1 (#1507) --- torchrl/objectives/sac.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index a82795ab1bb..de4908d1335 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -410,9 +410,9 @@ def target_entropy(self): else: action_container_shape = action_spec.shape target_entropy = -float( - action_spec[self.tensor_keys.action].shape[ - len(action_container_shape) : - ].numel() + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() ) self.register_buffer( "target_entropy_buffer", torch.tensor(target_entropy, device=device) @@ -619,24 +619,23 @@ def _qvalue_v1_loss( f"Batch size={tensordict.shape} is incompatible " f"with num_qvqlue_nets={self.num_qvalue_nets}." ) - tensordict_chunks = torch.stack( - tensordict.chunk(self.num_qvalue_nets, dim=0), 0 + tensordict_chunks = tensordict.reshape( + self.num_qvalue_nets, -1, *tensordict.shape[1:] + ) + target_chunks = target_value.reshape( + self.num_qvalue_nets, -1, *target_value.shape[1:] ) - target_chunks = torch.stack(target_value.chunk(self.num_qvalue_nets, dim=0), 0) # if vmap=True, it is assumed that the input tensordict must be cast to the param shape tensordict_chunks = self._vmap_qnetwork00( tensordict_chunks, self.qvalue_network_params ) - pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value).squeeze( - -1 - ) + pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value) + pred_val = pred_val.squeeze(-1) loss_value = distance_loss( pred_val, target_chunks, loss_function=self.loss_function ).view(*shape) - metadata = { - "td_error": torch.cat((pred_val - target_chunks).pow(2).unbind(0), 0) - } + metadata = {"td_error": (pred_val - target_chunks).pow(2).flatten(0, 1)} return loss_value, metadata From d2e11bfe3637145748708bde1ae58a36d2f57541 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 7 Sep 2023 16:42:00 +0100 Subject: [PATCH 2/5] [BugFix] Fix ClipTransform device (#1508) --- test/test_transforms.py | 24 ++++++++++++++++++------ torchrl/envs/transforms/transforms.py | 4 ++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fc2b74676c7..5c736dcedf0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -383,10 +383,11 @@ def test_transform_compose(self): assert data["reward"] == 2 assert data["reward_clip"] == 0.1 - def test_transform_env(self): - env = ContinuousActionVecMockEnv() + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_env(self, device): + base_env = ContinuousActionVecMockEnv(device=device) env = TransformedEnv( - env, + base_env, ClipTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -395,6 +396,7 @@ def test_transform_env(self): ), ) r = env.rollout(3) + assert r.device == device assert (r["observation"] <= 0.1).all() assert (r["next", "observation"] <= 0.1).all() assert (r["next", "reward"] <= 0.1).all() @@ -426,7 +428,7 @@ def test_transform_env(self): high=-1.0, ) env = TransformedEnv( - env, + base_env, ClipTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -436,7 +438,7 @@ def test_transform_env(self): ) check_env_specs(env) env = TransformedEnv( - env, + base_env, ClipTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -446,7 +448,7 @@ def test_transform_env(self): ) check_env_specs(env) env = TransformedEnv( - env, + base_env, ClipTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -455,6 +457,16 @@ def test_transform_env(self): ), ) check_env_specs(env) + env = TransformedEnv( + base_env, + ClipTransform( + in_keys=["observation", "reward"], + in_keys_inv=["observation_orig"], + low=-torch.ones(()), + high=1, + ), + ) + check_env_specs(env) def test_transform_inverse(self): t = ClipTransform( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8644abc8cb9..7ddb44b584e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1160,10 +1160,10 @@ def check_val(val): high, high_eps, high_max = check_val(high) if low is not None and high is not None and low >= high: raise ValueError("`low` must be stricly lower than `high`.") - self.low = low + self.register_buffer("low", low) self.low_eps = low_eps self.low_min = -low_min - self.high = high + self.register_buffer("high", high) self.high_eps = high_eps self.high_max = high_max From b7e52992d4f5838e11e49b48e476555a81637359 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Sep 2023 04:48:15 +0100 Subject: [PATCH 3/5] [CI] Move linux stable to GHA (#1503) --- .circleci/unittest/linux/scripts/run_all.sh | 17 +++++-- .github/workflows/test-linux-cpu.yml | 1 + .github/workflows/test-linux-gpu.yml | 1 + .github/workflows/test-linux-stable-gpu.yml | 50 +++++++++++++++++++++ .github/workflows/test-macos-cpu.yml | 1 + 5 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/test-linux-stable-gpu.yml diff --git a/.circleci/unittest/linux/scripts/run_all.sh b/.circleci/unittest/linux/scripts/run_all.sh index f62562d4715..43edc768011 100755 --- a/.circleci/unittest/linux/scripts/run_all.sh +++ b/.circleci/unittest/linux/scripts/run_all.sh @@ -122,10 +122,21 @@ fi git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" -if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu + else + pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$CU_VERSION + fi else - pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + printf "Failed to install pytorch" + exit 1 fi # smoke test diff --git a/.github/workflows/test-linux-cpu.yml b/.github/workflows/test-linux-cpu.yml index d97d7807a59..f3490ee474b 100644 --- a/.github/workflows/test-linux-cpu.yml +++ b/.github/workflows/test-linux-cpu.yml @@ -34,6 +34,7 @@ jobs: # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" + export TORCH_VERSION=nightly echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" diff --git a/.github/workflows/test-linux-gpu.yml b/.github/workflows/test-linux-gpu.yml index d576f813e4c..b5e1577d65f 100644 --- a/.github/workflows/test-linux-gpu.yml +++ b/.github/workflows/test-linux-gpu.yml @@ -39,6 +39,7 @@ jobs: # Commenting these out for now because the GPU test are not working inside docker export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + export TORCH_VERSION=nightly # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" diff --git a/.github/workflows/test-linux-stable-gpu.yml b/.github/workflows/test-linux-stable-gpu.yml new file mode 100644 index 00000000000..cf2681b10ff --- /dev/null +++ b/.github/workflows/test-linux-stable-gpu.yml @@ -0,0 +1,50 @@ +name: Unit-tests on Linux GPU, latest stable release + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +env: + CHANNEL: "nightly" + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + tests: + strategy: + matrix: + python_version: ["3.9"] # "3.8", "3.9", "3.10", "3.11" + cuda_arch_version: ["11.8"] # "11.6", "11.7" + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.g5.4xlarge.nvidia.gpu + repository: pytorch/rl + docker-image: "nvidia/cuda:12.1.0-devel-ubuntu22.04" + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 90 + script: | + # Set env vars from matrix + export PYTHON_VERSION=${{ matrix.python_version }} + # Commenting these out for now because the GPU test are not working inside docker + export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} + export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" + export TORCH_VERSION=stable + # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines + #export CU_VERSION="cpu" + + echo "PYTHON_VERSION: $PYTHON_VERSION" + echo "CU_VERSION: $CU_VERSION" + + ## setup_env.sh + bash .circleci/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-macos-cpu.yml b/.github/workflows/test-macos-cpu.yml index c802ca86f19..278125230cf 100644 --- a/.github/workflows/test-macos-cpu.yml +++ b/.github/workflows/test-macos-cpu.yml @@ -33,6 +33,7 @@ jobs: export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" export SYSTEM_VERSION_COMPAT=0 + export TORCH_VERSION=nightly echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" From c62781c673ac75b30973c79e08a0ebad0b3222ee Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:22:33 +0100 Subject: [PATCH 4/5] [BugFix] Add `torch.no_grad()` for rendering in multiagent PPO tutorial (#1511) Signed-off-by: Matteo Bettini --- tutorials/sphinx-tutorials/multiagent_ppo.py | 30 +++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index c84be33a84e..4d35b18a360 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -718,13 +718,14 @@ # # .. code-block:: python # -# env.rollout( -# max_steps=max_steps, -# policy=policy, -# callback=lambda env, _: env.render(), -# auto_cast_to_device=True, -# break_when_any_done=False, -# ) +# with torch.no_grad(): +# env.rollout( +# max_steps=max_steps, +# policy=policy, +# callback=lambda env, _: env.render(), +# auto_cast_to_device=True, +# break_when_any_done=False, +# ) # # If you are running this in Google Colab, you can render the trained policy by running: # @@ -745,13 +746,14 @@ # def rendering_callback(env, td): # env.frames.append(Image.fromarray(env.render(mode="rgb_array"))) # env.frames = [] -# env.rollout( -# max_steps=max_steps, -# policy=policy, -# callback=rendering_callback, -# auto_cast_to_device=True, -# break_when_any_done=False, -# ) +# with torch.no_grad(): +# env.rollout( +# max_steps=max_steps, +# policy=policy, +# callback=rendering_callback, +# auto_cast_to_device=True, +# break_when_any_done=False, +# ) # env.frames[0].save( # f"{scenario_name}.gif", # save_all=True, From 6de66ed8fe02f205178abf35f5add3bfb932bf58 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 8 Sep 2023 07:33:51 -0400 Subject: [PATCH 5/5] fixes --- examples/dqn/dqn_atari.py | 63 ++++++++++++++++----------- examples/dqn/dqn_carpole.py | 36 ++++++++++----- torchrl/data/tensor_specs.py | 4 +- torchrl/envs/common.py | 3 +- torchrl/envs/gym_like.py | 2 +- torchrl/envs/transforms/transforms.py | 29 ++++++------ torchrl/record/loggers/csv.py | 18 ++++++-- 7 files changed, 96 insertions(+), 59 deletions(-) diff --git a/examples/dqn/dqn_atari.py b/examples/dqn/dqn_atari.py index 31b793e94f1..f8610c0277c 100644 --- a/examples/dqn/dqn_atari.py +++ b/examples/dqn/dqn_atari.py @@ -3,40 +3,36 @@ Deep Q-Learning Algorithm on Atari Environments. """ -import gym -import tqdm -import time import random +import time + +import gym +import numpy as np import torch.nn import torch.optim -import numpy as np +import tqdm from tensordict import TensorDict from torchrl.collectors import SyncDataCollector from torchrl.data import CompositeSpec, LazyMemmapStorage, TensorDictReplayBuffer -from torchrl.envs.libs.gym import GymWrapper from torchrl.envs import ( + CatFrames, default_info_dict_reader, - Resize, - VecNorm, + DoubleToFloat, + ExplorationType, GrayScale, + NoopResetEnv, + Resize, + RewardClipping, RewardSum, - CatFrames, + set_exploration_type, StepCounter, ToTensorImage, - DoubleToFloat, - RewardClipping, TransformedEnv, - NoopResetEnv, - ExplorationType, - set_exploration_type, -) -from torchrl.modules import ( - MLP, - ConvNet, - QValueActor, - EGreedyWrapper, + VecNorm, ) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ConvNet, EGreedyWrapper, MLP, QValueActor from torchrl.objectives import DQNLoss, HardUpdate from torchrl.record.loggers import generate_exp_name, get_logger @@ -91,6 +87,7 @@ def make_env(env_name, device, is_test=False): # env.append_transform(VecNorm(in_keys=["pixels"])) return env + # ==================================================================== # Model utils # -------------------------------------------------------------------- @@ -137,6 +134,7 @@ def make_dqn_model(env_name): # Collector utils # -------------------------------------------------------------------- + def make_collector(env_name, policy, device): collector_class = SyncDataCollector collector = collector_class( @@ -151,15 +149,16 @@ def make_collector(env_name, policy, device): collector.set_seed(seed) return collector + # ==================================================================== # Collector and replay buffer utils # -------------------------------------------------------------------- def make_replay_buffer( - batch_size, - buffer_scratch_dir="/tmp/", - prefetch=3, + batch_size, + buffer_scratch_dir="/tmp/", + prefetch=3, ): replay_buffer = TensorDictReplayBuffer( pin_memory=False, @@ -173,6 +172,7 @@ def make_replay_buffer( ) return replay_buffer + # ==================================================================== # Discrete DQN Loss # -------------------------------------------------------------------- @@ -187,9 +187,12 @@ def make_loss_module(value_network): delay_value=True, ) dqn_loss.make_value_estimator(gamma=gamma) - targ_net_updater = HardUpdate(dqn_loss, value_network_update_interval=hard_update_freq) + targ_net_updater = HardUpdate( + dqn_loss, value_network_update_interval=hard_update_freq + ) return dqn_loss, targ_net_updater + # ==================================================================== # Other component utils # -------------------------------------------------------------------- @@ -234,7 +237,9 @@ def make_logger(backend="csv"): # Make the components model = make_dqn_model(env_name) - model_explore = EGreedyWrapper(model, annealing_num_steps=annealing_frames, eps_end=end_e).to(device) + model_explore = EGreedyWrapper( + model, annealing_num_steps=annealing_frames, eps_end=end_e + ).to(device) collector = make_collector(env_name, model_explore, device) replay_buffer = make_replay_buffer(batch_size) loss_module, target_net_updater = make_loss_module(model) @@ -254,8 +259,14 @@ def make_logger(backend="csv"): episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - logger.log_scalar("reward_train", episode_rewards.mean().item(), collected_frames) - logger.log_scalar("episode_length_train", episode_length.sum().item() / len(episode_length), collected_frames) + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) + logger.log_scalar( + "episode_length_train", + episode_length.sum().item() / len(episode_length), + collected_frames, + ) pbar.update(data.numel()) data = data.reshape(-1) diff --git a/examples/dqn/dqn_carpole.py b/examples/dqn/dqn_carpole.py index 3534563f8e0..ff30b26e434 100644 --- a/examples/dqn/dqn_carpole.py +++ b/examples/dqn/dqn_carpole.py @@ -2,17 +2,18 @@ DQN Benchmarks: CartPole-v1 """ -import tqdm import time + import torch.nn import torch.optim +import tqdm from tensordict import TensorDict from torchrl.collectors import SyncDataCollector from torchrl.data import CompositeSpec, LazyTensorStorage, TensorDictReplayBuffer +from torchrl.envs import DoubleToFloat, RewardSum, StepCounter, TransformedEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.envs import RewardSum, DoubleToFloat, TransformedEnv, StepCounter +from torchrl.modules import EGreedyWrapper, MLP, QValueActor from torchrl.objectives import DQNLoss, HardUpdate -from torchrl.modules import MLP, QValueActor, EGreedyWrapper from torchrl.record.loggers import generate_exp_name, get_logger @@ -20,6 +21,7 @@ # Environment utils # -------------------------------------------------------------------- + def make_env(env_name="CartPole-v1", device="cpu"): env = GymEnv(env_name, device=device) env = TransformedEnv(env) @@ -28,6 +30,7 @@ def make_env(env_name="CartPole-v1", device="cpu"): env.append_transform(DoubleToFloat()) return env + # ==================================================================== # Model utils # -------------------------------------------------------------------- @@ -68,7 +71,7 @@ def make_dqn_model(env_name): device = "cpu" if not torch.cuda.is_available() else "cuda" env_name = "CartPole-v1" - total_frames = 500_000 + total_frames = 5_000 # 500_000 record_interval = 500_000 frames_per_batch = 10 num_updates = 1 @@ -87,10 +90,11 @@ def make_dqn_model(env_name): # Make the components model = make_dqn_model(env_name) - model_explore = EGreedyWrapper(model, annealing_num_steps=annealing_frames, eps_end=eps_end).to(device) + model_explore = EGreedyWrapper( + model, annealing_num_steps=annealing_frames, eps_end=eps_end + ).to(device) # Create the collector - collector_class = SyncDataCollector collector = SyncDataCollector( make_env(env_name, device), policy=model_explore, @@ -121,7 +125,9 @@ def make_dqn_model(env_name): delay_value=True, ) loss_module.make_value_estimator(gamma=gamma) - target_net_updater = HardUpdate(loss_module, value_network_update_interval=hard_update_freq) + target_net_updater = HardUpdate( + loss_module, value_network_update_interval=hard_update_freq + ) # Create the optimizer optimizer = torch.optim.Adam(loss_module.parameters(), lr=lr) @@ -138,12 +144,22 @@ def make_dqn_model(env_name): for i, data in enumerate(collector): # Train loging - logger.log_scalar("q_values", (data["action_value"]*data["action"]).sum().item() / frames_per_batch, collected_frames) + logger.log_scalar( + "q_values", + (data["action_value"] * data["action"]).sum().item() / frames_per_batch, + collected_frames, + ) episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] - logger.log_scalar("reward_train", episode_rewards.mean().item(), collected_frames) - logger.log_scalar("episode_length_train", episode_length.sum().item() / len(episode_length), collected_frames) + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) + logger.log_scalar( + "episode_length_train", + episode_length.sum().item() / len(episode_length), + collected_frames, + ) pbar.update(data.numel()) data = data.reshape(-1) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a4aa0278c55..559a34c1df1 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -555,10 +555,10 @@ def encode( val = torch.tensor(val, device=self.device, dtype=self.dtype) else: val = torch.as_tensor(val, dtype=self.dtype) - if val != self.shape: + if val.shape != self.shape: # if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end - if val == self.shape and self.shape[-1] == 1: + if val.shape == self.shape and self.shape[-1] == 1: val = val.unsqueeze(-1) else: try: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index ad6df69cf78..5ecdf148238 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -231,6 +231,7 @@ def __init__( self.__dict__["_done_keys"] = None self.__dict__["_reward_keys"] = None self.__dict__["_action_keys"] = None + self.__dict__["_batch_size"] = None if device is not None: self.__dict__["_device"] = torch.device(device) output_spec = self.__dict__.get("_output_spec", None) @@ -320,7 +321,7 @@ def run_type_checks(self, run_type_checks: bool) -> None: @property def batch_size(self) -> torch.Size: - _batch_size = getattr(self, "_batch_size", None) + _batch_size = self.__dict__["_batch_size"] if _batch_size is None: _batch_size = self._batch_size = torch.Size([]) return _batch_size diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 43eda6128bf..289bb731278 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -153,7 +153,7 @@ def read_reward(self, reward): reward (torch.Tensor or TensorDict): reward to be mapped. """ - return self.reward_spec.encode(reward) + return self.reward_spec.encode(reward, ignore_device=True) def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7ddb44b584e..e86558e1b5f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -730,23 +730,24 @@ def insert_transform(self, index: int, transform: Transform) -> None: self._erase_metadata() def __getattr__(self, attr: str) -> Any: - if attr in self.__dir__(): + try: return super().__getattr__( attr ) # make sure that appropriate exceptions are raised - elif attr.startswith("__"): - raise AttributeError( - "passing built-in private methods is " - f"not permitted with type {type(self)}. " - f"Got attribute {attr}." - ) - elif "base_env" in self.__dir__(): - base_env = self.__getattr__("base_env") - return getattr(base_env, attr) + except Exception as err: + if attr.startswith("__"): + raise AttributeError( + "passing built-in private methods is " + f"not permitted with type {type(self)}. " + f"Got attribute {attr}." + ) + elif "base_env" in self.__dir__(): + base_env = self.__getattr__("base_env") + return getattr(base_env, attr) raise AttributeError( f"env not set in {self.__class__.__name__}, cannot access {attr}" - ) + ) from err def __repr__(self) -> str: env_str = indent(f"env={self.base_env}", 4 * " ") @@ -3988,9 +3989,8 @@ def _step( for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key in next_tensordict.keys(include_nested=True): reward = next_tensordict.get(in_key) - if out_key not in tensordict.keys(True): - tensordict.set(out_key, torch.zeros_like(reward)) - next_tensordict.set(out_key, tensordict.get(out_key) + reward) + prev_reward = tensordict.get(out_key, 0.0) + next_tensordict.set(out_key, prev_reward + reward) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return next_tensordict @@ -4154,7 +4154,6 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - tensordict = tensordict.clone(False) step_count = tensordict.get(self.step_count_key) next_step_count = step_count + 1 next_tensordict.set(self.step_count_key, next_step_count) diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index ea94444ddc2..b7157cc259a 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -26,14 +26,18 @@ def __init__(self, log_dir: str): os.makedirs(os.path.join(self.log_dir, "videos")) os.makedirs(os.path.join(self.log_dir, "texts")) + self.files = {} + def add_scalar(self, name: str, value: float, global_step: Optional[int] = None): if global_step is None: global_step = len(self.scalars[name]) value = float(value) self.scalars[name].append((global_step, value)) filepath = os.path.join(self.log_dir, "scalars", "".join([name, ".csv"])) - with open(filepath, "a") as fd: - fd.write(",".join([str(global_step), str(value)]) + "\n") + if filepath not in self.files: + self.files[filepath] = open(filepath, "a") + fd = self.files[filepath] + fd.write(",".join([str(global_step), str(value)]) + "\n") def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs): if global_step is None: @@ -53,12 +57,18 @@ def add_text(self, tag, text, global_step: Optional[int] = None): filepath = os.path.join( self.log_dir, "texts", "".join([tag, str(global_step)]) + ".txt" ) - with open(filepath, "w+") as f: - f.writelines(text) + if filepath not in self.files: + self.files[filepath] = open(filepath, "w+") + fd = self.files[filepath] + fd.writelines(text) def __repr__(self) -> str: return f"CSVExperiment(log_dir={self.log_dir})" + def __del__(self): + for val in getattr(self, "files", {}).values(): + val.close() + class CSVLogger(Logger): """A minimal-dependecy CSV-logger.