Merge pull request #1 from vmoens/refactor_tcq
Refactor TQC
maxweissenbacher authored Nov 10, 2023
2 parents 6c80564 + ce69631 commit 2e56b5b
Showing 13 changed files with 1,119 additions and 184 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_libs/scripts_d4rl/environment.yml
@@ -17,3 +17,4 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- cython<3
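(A likely rationale for the new cython<3 pin, stated here as an assumption: mujoco-py, which d4rl pulls in, does not build against Cython 3.x, so capping Cython keeps this CI environment installable.)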
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following components
Writer
RoundRobinWriter
TensorDictRoundRobinWriter
TensorDictMaxValueWriter

Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
:class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.
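As a concrete illustration of this advice, here is a minimal sketch of a memmap-backed buffer with an explicit on-disk location; the path and sizes are placeholders, not part of this diff:

from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

# Memmap-backed storage written to a node-visible directory, as the
# paragraph above recommends for distributed settings.
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(
        max_size=1_000_000,
        scratch_dir="/shared/rb",  # placeholder path; aids node-failure recovery
        device="cpu",
    ),
    batch_size=256,
)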
13 changes: 11 additions & 2 deletions docs/source/reference/objectives.rst
@@ -138,7 +138,7 @@ CQL
CQLLoss

DT
----
--

.. autosummary::
:toctree: generated/
@@ -148,14 +148,23 @@ DT
OnlineDTLoss

TD3
----
---

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TD3Loss

TQC
---

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

TQCLoss

PPO
---

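For orientation, a hedged sketch of constructing the newly documented loss, mirroring the make_loss_module call in examples/tqc/utils.py further down; actor and critic stand for the models built there, and the numeric values are placeholders:

from torchrl.objectives import TQCLoss
from torchrl.objectives.utils import ValueEstimators

# Sketch only: argument names follow the call in utils.py below.
loss_module = TQCLoss(
    actor_network=actor,          # a ProbabilisticActor
    qvalue_network=critic,        # the quantile critic ensemble
    device="cpu",
    gamma=0.99,
    top_quantiles_to_drop=2 * 5,  # per-net drop count times n_nets (placeholder)
    alpha_init=1.0,
)
loss_module.make_value_estimator(value_type=ValueEstimators.TD0)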
7 changes: 5 additions & 2 deletions examples/tqc/tqc.py
@@ -29,6 +29,7 @@
"""

import time

import hydra
import numpy as np
import torch
@@ -57,7 +58,7 @@ def main(cfg: "DictConfig"): # noqa: F821
exp_name = generate_exp_name("SAC", cfg.env.exp_name)
logger = None
# TO-DO: Add logging back in before pushing to git repo
#if cfg.logger.backend:
# if cfg.logger.backend:
# logger = get_logger(
# logger_type=cfg.logger.backend,
# logger_name="sac_logging/wandb",
@@ -190,7 +191,9 @@ def main(cfg: "DictConfig"): # noqa: F821
episode_length
)
if collected_frames >= init_random_frames:
metrics_to_log["train/critic_loss"] = losses.get("loss_critic").mean().item()
metrics_to_log["train/critic_loss"] = (
losses.get("loss_critic").mean().item()
)
metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
metrics_to_log["train/alpha"] = loss_td["alpha"].item()
204 changes: 31 additions & 173 deletions examples/tqc/utils.py
@@ -5,23 +5,27 @@

import tempfile
from contextlib import nullcontext
from typing import Tuple

import torch
import numpy as np
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
from torchrl.data import (
CompositeSpec,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, ProbabilisticActor, ValueOperator, ActorCriticWrapper
from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal
from torchrl.objectives import SoftUpdate
from torchrl.data import CompositeSpec
from torchrl.objectives.common import LossModule
from torchrl.objectives import SoftUpdate, TQCLoss
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
@@ -30,8 +34,6 @@
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
from tensordict.tensordict import TensorDict, TensorDictBase
from typing import Tuple


# ====================================================================
@@ -100,17 +102,17 @@ def make_collector(cfg, train_env, actor_model_explore):


def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1_000_000,
buffer_scratch_dir=None,
device="cpu",
prefetch=3,
batch_size,
prb=False,
buffer_size=1_000_000,
buffer_scratch_dir=None,
device="cpu",
prefetch=3,
):
with (
tempfile.TemporaryDirectory()
if buffer_scratch_dir is None
else nullcontext(buffer_scratch_dir)
tempfile.TemporaryDirectory()
if buffer_scratch_dir is None
else nullcontext(buffer_scratch_dir)
) as scratch_dir:
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
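(Aside: the TemporaryDirectory/nullcontext pairing above is a compact idiom, letting one with-statement cover both the no-directory-supplied and the caller-supplied-path cases. A self-contained sketch with illustrative names:

import tempfile
from contextlib import nullcontext

def scratch(path=None):
    # Real temporary directory when no path is given; otherwise a no-op
    # context manager that simply yields the caller's path unchanged.
    return tempfile.TemporaryDirectory() if path is None else nullcontext(path)

with scratch() as d:  # auto-created, auto-deleted on exit
    print("buffer files go in", d)
with scratch("/data/rb") as d:  # caller-managed location, left untouched
    print("buffer files go in", d)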
@@ -155,13 +157,15 @@ def __init__(self, cfg):
}
for i in range(cfg.network.n_nets):
net = MLP(**qvalue_net_kwargs)
self.add_module(f'critic_net_{i}', net)
self.add_module(f"critic_net_{i}", net)
self.nets.append(net)

def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
if len(inputs) > 1:
inputs = (torch.cat([*inputs], -1),)
quantiles = torch.stack(tuple(net(*inputs) for net in self.nets), dim=-2) # batch x n_nets x n_quantiles
quantiles = torch.stack(
tuple(net(*inputs) for net in self.nets), dim=-2
) # batch x n_nets x n_quantiles
return quantiles


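A quick shape check of the ensemble forward above, under hypothetical sizes (5 nets, 25 quantiles, batch of 256; the observation and action dimensions are made up):

import torch

batch, n_nets, n_quantiles = 256, 5, 25
obs, act = torch.randn(batch, 17), torch.randn(batch, 6)
x = torch.cat([obs, act], -1)  # the concatenation done inside forward
per_net = [torch.randn(batch, n_quantiles) for _ in range(n_nets)]  # stand-ins for net(x)
quantiles = torch.stack(per_net, dim=-2)
assert quantiles.shape == (batch, n_nets, n_quantiles)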
@@ -239,172 +243,26 @@ def make_tqc_agent(cfg, train_env, eval_env, device):
return model, model[0]


# ====================================================================
# Quantile Huber Loss
# -------------------


def quantile_huber_loss_f(quantiles, samples):
"""
Quantile Huber loss from the original PyTorch TQC implementation.
See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
quantiles is assumed to be of shape [batch size, n_nets, n_quantiles]
samples is assumed to be of shape [batch size, n_samples]
Arbitrary batch sizes are allowed.
"""
pairwise_delta = samples[..., None, None, :] - quantiles[..., None] # batch x n_nets x n_quantiles x n_samples
abs_pairwise_delta = torch.abs(pairwise_delta)
huber_loss = torch.where(abs_pairwise_delta > 1,
abs_pairwise_delta - 0.5,
pairwise_delta ** 2 * 0.5)
n_quantiles = quantiles.shape[-1]
tau = torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
loss = (torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
return loss


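To make the broadcasting concrete: tau holds the quantile midpoints (i + 0.5) / n_quantiles, and the factor |tau - 1(delta < 0)| tilts the Huber loss so that each quantile head regresses toward its own quantile of the target distribution. A usage sketch with hypothetical sizes; in TQC the target carries n_nets * n_quantiles - top_quantiles_to_drop atoms, e.g. 5 * 25 - 10 = 115:

import torch

quantiles = torch.randn(256, 5, 25)  # batch x n_nets x n_quantiles
samples = torch.randn(256, 115)      # batch x n_target_atoms
loss = quantile_huber_loss_f(quantiles, samples)
print(loss)  # a scalar, averaged over all (net, quantile, sample) pairs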
# ====================================================================
# TQC Loss
# --------

class TQCLoss(LossModule):
def __init__(
self,
actor_network,
qvalue_network,
gamma,
top_quantiles_to_drop,
alpha_init,
device
):
super().__init__()

self.convert_to_functional(
actor_network,
"actor",
create_target_params=False,
funs_to_decorate=["forward", "get_dist"],
)

self.convert_to_functional(
qvalue_network,
"critic",
create_target_params=True # Create a target critic network
)

self.device = device
self.log_alpha = torch.tensor([np.log(alpha_init)], requires_grad=True, device=self.device)
self.gamma = gamma
self.top_quantiles_to_drop = top_quantiles_to_drop

# Compute target entropy
action_spec = getattr(self.actor, "spec", None)
if action_spec is None:
print("Could not deduce action spec from actor network.")
if not isinstance(action_spec, CompositeSpec):
action_spec = CompositeSpec({"action": action_spec})
action_container_len = len(action_spec.shape)
self.target_entropy = -float(action_spec["action"].shape[action_container_len:].numel())
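# (The -numel() above is the standard SAC heuristic: the target entropy is
# minus the number of action dimensions, e.g. -6.0 for a 6-dimensional
# continuous action space.)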

def value_loss(self, tensordict):
td_next = tensordict.get("next")
reward = td_next.get("reward")
not_done = tensordict.get("done").logical_not()
alpha = torch.exp(self.log_alpha)

# Q-loss
with torch.no_grad():
# get policy action
self.actor(td_next, params=self.actor_params)
self.critic(td_next, params=self.target_critic_params)

next_log_pi = td_next.get("sample_log_prob")
next_log_pi = torch.unsqueeze(next_log_pi, dim=-1)

# compute and cut quantiles at the next state
next_z = td_next.get("state_action_value")
sorted_z, _ = torch.sort(next_z.reshape(*tensordict.batch_size, -1))
sorted_z_part = sorted_z[..., :-self.top_quantiles_to_drop]
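# e.g. with 5 nets x 25 quantiles and top_quantiles_to_drop = 10 (hypothetical
# values), sorted_z holds 125 atoms per transition and sorted_z_part keeps the
# smallest 115; dropping the largest atoms curbs overestimation.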

# compute target
# --- Note ---
# This is computed manually here, since the built-in value estimators in the library
# currently do not support a critic of a shape different from the reward.
# ------------
target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)

self.critic(tensordict, params=self.critic_params)
cur_z = tensordict.get("state_action_value")
critic_loss = quantile_huber_loss_f(cur_z, target)
return critic_loss

def actor_loss(self, tensordict):
alpha = torch.exp(self.log_alpha)
self.actor(tensordict, params=self.actor_params)
self.critic(tensordict, params=self.critic_params)
new_log_pi = tensordict.get("sample_log_prob")
actor_loss = (alpha * new_log_pi - tensordict.get("state_action_value").mean(-1).mean(-1, keepdim=True)).mean()
return actor_loss, new_log_pi

def alpha_loss(self, log_prob):
alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean()
return alpha_loss

def entropy(self, tensordict):
with set_exploration_type(ExplorationType.RANDOM):
dist = self.actor.get_dist(
tensordict,
params=self.actor_params,
)
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm).detach()
entropy = -log_prob.mean()
return entropy

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
alpha = torch.exp(self.log_alpha)
critic_loss = self.value_loss(tensordict)
actor_loss, log_prob = self.actor_loss(tensordict) # Compute actor loss AFTER critic loss
alpha_loss = self.alpha_loss(log_prob)
entropy = self.entropy(tensordict)

return TensorDict(
{
"loss_critic": critic_loss,
"loss_actor": actor_loss,
"loss_alpha": alpha_loss,
"alpha": alpha,
"entropy": entropy,
},
batch_size=[]
)

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
"""
Dummy method: it only checks that the requested value type is TD0 and raises
an error otherwise. As of this writing, the value estimators in the library
do not support a critic whose output shape differs from the reward's, which
TQC requires by construction. This method therefore builds no value
estimator; the target value is computed "by hand" in the value_loss
method above.
"""
if value_type is not ValueEstimators.TD0:
raise NotImplementedError(f"Value type {value_type} is not currently implemented.")


def make_loss_module(cfg, model):
"""Make loss module and target network updater."""
# Create TQC loss
top_quantiles_to_drop = (
cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets
)
loss_module = TQCLoss(
actor_network=model[0],
qvalue_network=model[1],
device=cfg.network.device,
gamma=cfg.optim.gamma,
top_quantiles_to_drop=cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets,
alpha_init=cfg.optim.alpha_init
top_quantiles_to_drop=top_quantiles_to_drop,
alpha_init=cfg.optim.alpha_init,
)
loss_module.make_value_estimator(
value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma
)
loss_module.make_value_estimator(value_type=ValueEstimators.TD0)

# Define Target Network Updater
target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
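Downstream, the training script consumes the returned pair roughly as follows. A hedged sketch based on the loss keys visible in this diff; the single combined optimizer step and the optimizer/batch names are illustrative:

loss_td = loss_module(batch)  # batch: a TensorDict sampled from the replay buffer
loss = loss_td["loss_critic"] + loss_td["loss_actor"] + loss_td["loss_alpha"]
optimizer.zero_grad()
loss.backward()
optimizer.step()
target_net_updater.step()  # Polyak-average the target critic parameters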