Commit: amend

vmoens committed Nov 22, 2023
1 parent 343214f commit 7b6beca
Showing 12 changed files with 84 additions and 175 deletions.

test/test_cost.py: 11 changes (4 additions & 7 deletions)
@@ -47,7 +47,7 @@
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv
- from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule
+ from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer

# from torchrl.data.postprocs.utils import expand_as_right
@@ -6811,20 +6811,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
advantage = GAE(
gamma=gamma,
lmbda=0.9,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td":
advantage = TD1Estimator(
gamma=gamma,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimator(
gamma=0.9,
lmbda=0.9,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage is None:
@@ -9633,9 +9633,6 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):
next_state_value = torch.randn(*N, T, 1, device=device)

gamma_tensor = torch.full((*N, T, 1), gamma, device=device)
- # if len(N) == 2:
- # print(terminated[4, 0, -10:])
- # print(done[4, 0, -10:])
v1 = vec_td_lambda_advantage_estimate(
gamma,
lmbda,
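
Note: the test changes above track an API simplification on the value-estimator side: value networks are now passed to GAE, TD1Estimator and TDLambdaEstimator as plain TensorDictModules, with no tensordict.nn.get_functional wrapping. A minimal sketch of the new call pattern (the network and keys below are illustrative, not taken from the test):

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

# toy value network; in-/out-keys chosen for illustration
value_net = TensorDictModule(
    nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
)
# the module is handed over as-is; no get_functional(value_net)
advantage = GAE(gamma=0.99, lmbda=0.9, value_network=value_net, differentiable=True)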

torchrl/objectives/common.py: 3 changes (2 additions & 1 deletion)
@@ -290,7 +290,8 @@ def _compare_and_expand(param):
# otherwise they will appear twice in parameters
p = TensorDict.from_module(module)
with params.detach().to("meta").to_module(module):
- setattr(self, module_name, deepcopy(module))
+ # avoid buffers and params being exposed
+ self.__dict__[module_name] = deepcopy(module)
assert (p == TensorDict.from_module(module)).all()

name_params_target = "target_" + module_name
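
The switch from setattr to a raw __dict__ write is the substantive part of this hunk: nn.Module.__setattr__ registers any nn.Module value as a submodule, which would expose the copied module's parameters and buffers a second time, whereas writing to __dict__ keeps the attribute reachable without registering it. A standalone sketch of the difference (class and attribute names are made up for illustration):

import copy
from torch import nn

class Holder(nn.Module):
    def __init__(self, sub: nn.Module):
        super().__init__()
        # goes through nn.Module.__setattr__: the child is registered and
        # its parameters appear in self.parameters()
        self.registered = sub
        # bypasses registration: still reachable as self.hidden, but its
        # parameters and buffers are not exposed via self.parameters()
        self.__dict__["hidden"] = copy.deepcopy(sub)

holder = Holder(nn.Linear(3, 4))
print(sum(p.numel() for p in holder.parameters()))  # 16: only the registered child
print(holder.hidden)  # the hidden copy is still an ordinary attribute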

torchrl/objectives/cql.py: 21 changes (4 additions & 17 deletions)
@@ -26,26 +26,15 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
_cache_values,
- _GAMMA_LMBDA_DEPREC_WARNING,vmap_func,
+ _GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
+ _vmap_func,
)

from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

- try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- _has_functorch = True
- err = ""
- except ImportError as err:
- _has_functorch = False
- FUNCTORCH_ERROR = err


class CQLLoss(LossModule):
"""TorchRL implementation of the continuous CQL loss.
@@ -267,8 +256,6 @@ def __init__(
priority_key: str = None,
) -> None:
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -348,8 +335,8 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

- self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
- self._vmap_qvalue_network00 = vmap_func(self.qvalue_network)
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)

@property
def target_entropy(self):
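
Here and in deprecated.py, iql.py and redq.py below, the per-file functorch import guards are dropped and raw vmap calls are routed through a shared _vmap_func helper from torchrl.objectives.utils. Judging only from the call sites in this diff, the helper vmaps a module call whose trailing argument is a (stack of) parameter TensorDict(s) swapped into the module for the call; a rough sketch of what such a helper could look like, offered as an assumption rather than the actual library code:

from torch import vmap

def _vmap_func(module, *args, func=None, **kwargs):
    # wrap the module so the last positional argument is treated as a
    # TensorDict of parameters bound to the module for the call
    def call(*inputs_and_params):
        *inputs, params = inputs_and_params
        with params.to_module(module):
            method = module if func is None else getattr(module, func)
            return method(*inputs)

    return vmap(call, *args, **kwargs)

With this shape, _vmap_func(self.qvalue_network, (None, 0)) maps a single input tensordict over dimension 0 of stacked Q-network parameters, which matches the in_dims used at the call sites below.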

torchrl/objectives/deprecated.py: 22 changes (6 additions & 16 deletions)
@@ -21,21 +21,13 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
from torchrl.objectives.common import LossModule
- from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING
+ from torchrl.objectives.utils import (
+ _cache_values,
+ _GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
+ )
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

- try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- FUNCTORCH_ERR = ""
- _has_functorch = True
- except ImportError as err:
- FUNCTORCH_ERR = str(err)
- _has_functorch = False


class REDQLoss_deprecated(LossModule):
"""REDQ Loss module.
@@ -149,8 +141,6 @@ def __init__(
):
self._in_keys = None
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -208,7 +198,7 @@ def __init__(
self.target_entropy_buffer = None
self.gSDE = gSDE

- self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)

torchrl/objectives/iql.py: 19 changes (3 additions & 16 deletions)
@@ -14,27 +14,16 @@

from torchrl.modules import ProbabilisticActor
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
- vmap_func,
+ _vmap_func,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

- try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- _has_functorch = True
- err = ""
- except ImportError as err:
- _has_functorch = False
- FUNCTORCH_ERROR = err


class IQLLoss(LossModule):
r"""TorchRL implementation of the IQL loss.
@@ -249,8 +238,6 @@ def __init__(
) -> None:
self._in_keys = None
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority=priority_key)

@@ -299,7 +286,7 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
- self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

@property
def device(self) -> torch.device:

torchrl/objectives/ppo.py: 16 changes (7 additions & 9 deletions)
@@ -271,9 +271,7 @@ def __init__(
self._in_keys = None
self._out_keys = None
super().__init__()
- self.convert_to_functional(
- actor, "actor", funs_to_decorate=["forward", "get_dist"]
- )
+ self.convert_to_functional(actor, "actor")
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
@@ -374,7 +372,8 @@ def _log_weight(
f"tensordict stored {self.tensor_keys.action} requires grad."
)

- dist = self.actor.get_dist(tensordict, params=self.actor_params)
+ with self.actor_params.to_module(self.actor):
+ dist = self.actor.get_dist(tensordict)
log_prob = dist.log_prob(action)

prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
@@ -400,10 +399,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
f"can be used for the value loss."
)

- state_value_td = self.critic(
- tensordict,
- params=self.critic_params,
- )
+ with self.critic_params.to_module(self.critic):
+ state_value_td = self.critic(tensordict)

try:
state_value = state_value_td.get(self.tensor_keys.value)
@@ -848,7 +845,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
neg_loss = log_weight.exp() * advantage

previous_dist = self.actor.build_dist_from_params(tensordict)
- current_dist = self.actor.get_dist(tensordict, params=self.actor_params)
+ with self.actor_params.to_module(self.actor):
+ current_dist = self.actor.get_dist(tensordict)
try:
kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
except NotImplementedError:
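
The PPO edits replace the functional-call keyword (params=self.actor_params, params=self.critic_params) with tensordict's context-manager form, in which a parameter TensorDict is temporarily bound to the module for the duration of the call. A small self-contained sketch of that pattern on a plain module (module and shapes are made up):

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)
params = TensorDict.from_module(module)  # pull the parameters out as a TensorDict

with params.to_module(module):           # temporarily bind these parameters
    out = module(torch.randn(3))         # ordinary call, no params= keyword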

torchrl/objectives/redq.py: 21 changes (4 additions & 17 deletions)
@@ -18,27 +18,17 @@
from torchrl.data import CompositeSpec
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
default_value_kwargs,
distance_loss,
ValueEstimators,
+ _vmap_func,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

- try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- FUNCTORCH_ERR = ""
- _has_functorch = True
- except ImportError as err:
- FUNCTORCH_ERR = str(err)
- _has_functorch = False


class REDQLoss(LossModule):
"""REDQ Loss module.
@@ -265,8 +255,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
):
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR

super().__init__()
self._in_keys = None
@@ -276,7 +264,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
- funs_to_decorate=["forward", "get_dist_params"],
)

# let's make sure that actor_network has `return_log_prob` set to True
@@ -331,8 +318,8 @@ def __init__(
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma

- self._vmap_qvalue_network00 = vmap(self.qvalue_network)
- self._vmap_getdist = vmap(self.actor_network.get_dist_params)
+ self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+ self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")

@property
def target_entropy(self):
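
The removal of funs_to_decorate here and in ppo.py above follows from the same change of calling convention: once a method runs under a params.to_module(...) context, any method of the module sees the swapped-in parameters, so convert_to_functional no longer needs a list of methods to decorate, and non-forward methods such as get_dist_params can be vmapped via _vmap_func(actor_network, func="get_dist_params"). A toy, purely illustrative sketch of a non-forward method picking up bound parameters (consistent with the _vmap_func sketch given earlier):

import torch
from torch import nn
from tensordict import TensorDict

class ToyActor(nn.Module):
    # illustrative module with a method besides forward
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 2)

    def forward(self, x):
        return self.net(x)

    def get_dist_params(self, x):
        # toy stand-in for returning distribution parameters
        loc, scale = self.net(x).chunk(2, -1)
        return loc, scale.exp()

actor = ToyActor()
params = TensorDict.from_module(actor)
with params.to_module(actor):
    # both forward and get_dist_params use the bound parameters;
    # no per-method decoration is needed
    loc, scale = actor.get_dist_params(torch.randn(3))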

torchrl/objectives/reinforce.py: 12 changes (4 additions & 8 deletions)
@@ -285,10 +285,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = tensordict.get(self.tensor_keys.advantage)

# compute log-prob
- tensordict = self.actor_network(
- tensordict,
- params=self.actor_network_params,
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ tensordict = self.actor_network(tensordict)

log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
if log_prob.shape == advantage.shape[:-1]:
@@ -305,10 +303,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
try:
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
- state_value = self.critic(
- tensordict_select,
- params=self.critic_params,
- ).get(self.tensor_keys.value)
+ with self.critic_params.to_module(self.critic):
+ state_value = self.critic(tensordict_select).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,