diff --git a/test/test_cost.py b/test/test_cost.py
index a0054675e1e..8d7fd51c53b 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -47,7 +47,7 @@
     get_default_devices,
 )
 from mocking_classes import ContinuousActionConvMockEnv
-from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule
+from tensordict.nn import NormalParamExtractor, TensorDictModule
 from tensordict.nn.utils import Buffer

 # from torchrl.data.postprocs.utils import expand_as_right
@@ -6811,20 +6811,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
             advantage = GAE(
                 gamma=gamma,
                 lmbda=0.9,
-                value_network=get_functional(value_net),
+                value_network=value_net,
                 differentiable=gradient_mode,
             )
         elif advantage == "td":
             advantage = TD1Estimator(
                 gamma=gamma,
-                value_network=get_functional(value_net),
+                value_network=value_net,
                 differentiable=gradient_mode,
             )
         elif advantage == "td_lambda":
             advantage = TDLambdaEstimator(
                 gamma=0.9,
                 lmbda=0.9,
-                value_network=get_functional(value_net),
+                value_network=value_net,
                 differentiable=gradient_mode,
             )
         elif advantage is None:
@@ -9633,9 +9633,6 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):
         next_state_value = torch.randn(*N, T, 1, device=device)

         gamma_tensor = torch.full((*N, T, 1), gamma, device=device)
-        # if len(N) == 2:
-        #     print(terminated[4, 0, -10:])
-        #     print(done[4, 0, -10:])
         v1 = vec_td_lambda_advantage_estimate(
             gamma,
             lmbda,
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index b941b01a3f0..5286aa31f4a 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -290,7 +290,8 @@ def _compare_and_expand(param):
         # otherwise they will appear twice in parameters
         p = TensorDict.from_module(module)
         with params.detach().to("meta").to_module(module):
-            setattr(self, module_name, deepcopy(module))
+            # avoid buffers and params being exposed
+            self.__dict__[module_name] = deepcopy(module)
         assert (p == TensorDict.from_module(module)).all()

         name_params_target = "target_" + module_name
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 7464048590f..d29c82c80aa 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -26,26 +26,15 @@
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
     _cache_values,
-    _GAMMA_LMBDA_DEPREC_WARNING,vmap_func,
+    _GAMMA_LMBDA_DEPREC_WARNING,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
+    _vmap_func,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    _has_functorch = True
-    err = ""
-except ImportError as err:
-    _has_functorch = False
-    FUNCTORCH_ERROR = err
-


 class CQLLoss(LossModule):
     """TorchRL implementation of the continuous CQL loss.
@@ -267,8 +256,6 @@ def __init__(
         priority_key: str = None,
     ) -> None:
         self._out_keys = None
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -348,8 +335,8 @@ def __init__(
             torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
         )

-        self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
-        self._vmap_qvalue_network00 = vmap_func(self.qvalue_network)
+        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)

     @property
     def target_entropy(self):
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index 5d94f5bbafa..947a7574967 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -21,21 +21,13 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
 from torchrl.objectives.common import LossModule
-from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING
+from torchrl.objectives.utils import (
+    _cache_values,
+    _GAMMA_LMBDA_DEPREC_WARNING,
+    _vmap_func,
+)
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    FUNCTORCH_ERR = ""
-    _has_functorch = True
-except ImportError as err:
-    FUNCTORCH_ERR = str(err)
-    _has_functorch = False
-

 class REDQLoss_deprecated(LossModule):
     """REDQ Loss module.
@@ -149,8 +141,6 @@ def __init__(
     ):
         self._in_keys = None
         self._out_keys = None
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -208,7 +198,7 @@ def __init__(
         self.target_entropy_buffer = None
         self.gSDE = gSDE

-        self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 67d926658dd..50a06ff19ef 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -14,27 +14,16 @@
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives.common import LossModule
+
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_WARNING,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
-    vmap_func,
+    _vmap_func,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    _has_functorch = True
-    err = ""
-except ImportError as err:
-    _has_functorch = False
-    FUNCTORCH_ERROR = err
-


 class IQLLoss(LossModule):
     r"""TorchRL implementation of the IQL loss.
@@ -249,8 +238,6 @@ def __init__(
     ) -> None:
         self._in_keys = None
         self._out_keys = None
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
         self._set_deprecated_ctor_keys(priority=priority_key)

@@ -299,7 +286,7 @@ def __init__(
         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
             self.gamma = gamma
-        self._vmap_qvalue_networkN0 = vmap_func(self.qvalue_network, (None, 0))
+        self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

     @property
     def device(self) -> torch.device:
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index e576ca33c1c..f0ac6f9104c 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -271,9 +271,7 @@ def __init__(
         self._in_keys = None
         self._out_keys = None
         super().__init__()
-        self.convert_to_functional(
-            actor, "actor", funs_to_decorate=["forward", "get_dist"]
-        )
+        self.convert_to_functional(actor, "actor")
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
             # params of critic must be refs to actor if they're shared
@@ -374,7 +372,8 @@ def _log_weight(
                 f"tensordict stored {self.tensor_keys.action} requires grad."
             )

-        dist = self.actor.get_dist(tensordict, params=self.actor_params)
+        with self.actor_params.to_module(self.actor):
+            dist = self.actor.get_dist(tensordict)
         log_prob = dist.log_prob(action)

         prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
@@ -400,10 +399,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"can be used for the value loss."
             )

-        state_value_td = self.critic(
-            tensordict,
-            params=self.critic_params,
-        )
+        with self.critic_params.to_module(self.critic):
+            state_value_td = self.critic(tensordict)

         try:
             state_value = state_value_td.get(self.tensor_keys.value)
@@ -848,7 +845,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         neg_loss = log_weight.exp() * advantage

         previous_dist = self.actor.build_dist_from_params(tensordict)
-        current_dist = self.actor.get_dist(tensordict, params=self.actor_params)
+        with self.actor_params.to_module(self.actor):
+            current_dist = self.actor.get_dist(tensordict)
         try:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index dd64a4bc033..3ba37c52b7b 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -18,27 +18,17 @@
 from torchrl.data import CompositeSpec
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives.common import LossModule
+
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_WARNING,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
+    _vmap_func,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    FUNCTORCH_ERR = ""
-    _has_functorch = True
-except ImportError as err:
-    FUNCTORCH_ERR = str(err)
-    _has_functorch = False
-

 class REDQLoss(LossModule):
     """REDQ Loss module.
@@ -265,8 +255,6 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
     ):
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
         super().__init__()
         self._in_keys = None

@@ -276,7 +264,6 @@ def __init__(
             actor_network,
             "actor_network",
             create_target_params=self.delay_actor,
-            funs_to_decorate=["forward", "get_dist_params"],
         )

         # let's make sure that actor_network has `return_log_prob` set to True
@@ -331,8 +318,8 @@ def __init__(
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
             self.gamma = gamma

-        self._vmap_qvalue_network00 = vmap(self.qvalue_network)
-        self._vmap_getdist = vmap(self.actor_network.get_dist_params)
+        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+        self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")

     @property
     def target_entropy(self):
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 93910f1eebf..0782187ec04 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -285,10 +285,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             advantage = tensordict.get(self.tensor_keys.advantage)

         # compute log-prob
-        tensordict = self.actor_network(
-            tensordict,
-            params=self.actor_network_params,
-        )
+        with self.actor_network_params.to_module(self.actor_network):
+            tensordict = self.actor_network(tensordict)

         log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
         if log_prob.shape == advantage.shape[:-1]:
@@ -305,10 +303,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         try:
             target_return = tensordict.get(self.tensor_keys.value_target)
             tensordict_select = tensordict.select(*self.critic.in_keys)
-            state_value = self.critic(
-                tensordict_select,
-                params=self.critic_params,
-            ).get(self.tensor_keys.value)
+            with self.critic_params.to_module(self.critic):
+                state_value = self.critic(tensordict_select).get(self.tensor_keys.value)
             loss_value = distance_loss(
                 target_return,
                 state_value,
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index cdd869c714a..b33b9b522e8 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -22,27 +22,17 @@
 from torchrl.modules import ProbabilisticActor
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
+
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_WARNING,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
+    _vmap_func,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    _has_functorch = True
-    err = ""
-except ImportError as err:
-    _has_functorch = False
-    FUNCTORCH_ERROR = err
-

 def _delezify(func):
     @wraps(func)
@@ -293,8 +283,6 @@ def __init__(
     ) -> None:
         self._in_keys = None
         self._out_keys = None
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -388,9 +376,9 @@ def __init__(
         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
             self.gamma = gamma
-        self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0))
+        self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
         if self._version == 1:
-            self._vmap_qnetwork00 = vmap(qvalue_network)
+            self._vmap_qnetwork00 = _vmap_func(qvalue_network)

     @property
     def target_entropy_buffer(self):
@@ -588,11 +576,10 @@ def _cached_detached_qvalue_params(self):
     def _actor_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
-        with set_exploration_type(ExplorationType.RANDOM):
-            dist = self.actor_network.get_dist(
-                tensordict,
-                params=self.actor_network_params,
-            )
+        with set_exploration_type(
+            ExplorationType.RANDOM
+        ), self.actor_network_params.to_module(self.actor_network):
+            dist = self.actor_network.get_dist(tensordict)
         a_reparm = dist.rsample()
         log_prob = dist.log_prob(a_reparm)

@@ -679,11 +666,11 @@ def _compute_target_v2(self, tensordict) -> Tensor:
         tensordict = tensordict.clone(False)
         # get actions and log-probs
         with torch.no_grad():
-            with set_exploration_type(ExplorationType.RANDOM):
+            with set_exploration_type(
+                ExplorationType.RANDOM
+            ), self.actor_network_params.to_module(self.actor_network):
                 next_tensordict = tensordict.get("next").clone(False)
-                next_dist = self.actor_network.get_dist(
-                    next_tensordict, params=self.actor_network_params
-                )
+                next_dist = self.actor_network.get_dist(next_tensordict)
                 next_action = next_dist.rsample()
                 next_tensordict.set(self.tensor_keys.action, next_action)
                 next_sample_log_prob = next_dist.log_prob(next_action)
@@ -735,16 +722,11 @@ def _value_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
         # value loss
         td_copy = tensordict.select(*self.value_network.in_keys).detach()
-        self.value_network(
-            td_copy,
-            params=self.value_network_params,
-        )
+        with self.value_network_params.to_module(self.value_network):
+            self.value_network(td_copy)
         pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1)
-
-        action_dist = self.actor_network.get_dist(
-            td_copy,
-            params=self.target_actor_network_params,
-        )  # resample an action
+        with self.target_actor_network_params.to_module(self.actor_network):
+            action_dist = self.actor_network.get_dist(td_copy)  # resample an action
         action = action_dist.rsample()

         td_copy.set(self.tensor_keys.action, action, inplace=False)
@@ -990,8 +972,6 @@ def __init__(
         separate_losses: bool = False,
     ):
         self._in_keys = None
-        if not _has_functorch:
-            raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
         super().__init__()
         self._set_deprecated_ctor_keys(priority_key=priority_key)

@@ -1069,7 +1049,7 @@ def __init__(
         self.register_buffer(
             "target_entropy", torch.tensor(target_entropy, device=device)
         )
-        self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0))
+        self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))

     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
@@ -1153,9 +1133,8 @@ def _compute_target(self, tensordict) -> Tensor:
             next_tensordict = tensordict.get("next").clone(False)

             # get probs and log probs for actions computed from "next"
-            next_dist = self.actor_network.get_dist(
-                next_tensordict, params=self.actor_network_params
-            )
+            with self.actor_network_params.to_module(self.actor_network):
+                next_dist = self.actor_network.get_dist(next_tensordict)
             next_prob = next_dist.probs
             next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))

@@ -1220,10 +1199,8 @@ def _actor_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
         # get probs and log probs for actions
-        dist = self.actor_network.get_dist(
-            tensordict,
-            params=self.actor_network_params,
-        )
+        with self.actor_network_params.to_module(self.actor_network):
+            dist = self.actor_network.get_dist(tensordict)
         prob = dist.probs
         log_prob = torch.log(torch.where(prob == 0, 1e-8, prob))

diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 9912c143ae6..524601793f6 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -15,27 +15,17 @@
 from torchrl.envs.utils import step_mdp
 from torchrl.objectives.common import LossModule
+
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_WARNING,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
+    _vmap_func,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-try:
-    try:
-        from torch import vmap
-    except ImportError:
-        from functorch import vmap
-
-    FUNCTORCH_ERR = ""
-    _has_functorch = True
-except ImportError as err:
-    FUNCTORCH_ERR = str(err)
-    _has_functorch = False
-


 class TD3Loss(LossModule):
     """TD3 Loss module.
@@ -229,10 +219,6 @@ def __init__(
         priority_key: str = None,
         separate_losses: bool = False,
     ) -> None:
-        if not _has_functorch:
-            raise ImportError(
-                f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
-            )

         super().__init__()
         self._in_keys = None
@@ -310,8 +296,8 @@ def __init__(
         if gamma is not None:
             warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
             self.gamma = gamma
-        self._vmap_qvalue_network00 = vmap(self.qvalue_network)
-        self._vmap_actor_network00 = vmap(self.actor_network)
+        self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+        self._vmap_actor_network00 = _vmap_func(self.actor_network)

     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
@@ -359,9 +345,8 @@ def _cached_stack_actor_params(self):

     def actor_loss(self, tensordict):
         tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys)
-        tensordict_actor_grad = self.actor_network(
-            tensordict_actor_grad, self.actor_network_params
-        )
+        with self.actor_network_params.to_module(self.actor_network):
+            tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
         actor_loss_td = tensordict_actor_grad.select(
             *self.qvalue_network.in_keys
         ).expand(
@@ -395,9 +380,8 @@ def value_loss(self, tensordict):
             next_td_actor = step_mdp(tensordict).select(
                 *self.actor_network.in_keys
             )  # next_observation ->
-            next_td_actor = self.actor_network(
-                next_td_actor, self.target_actor_network_params
-            )
+            with self.target_actor_network_params.to_module(self.actor_network):
+                next_td_actor = self.actor_network(next_td_actor)
             next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
                 self.min_action, self.max_action
             )
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index c7e558fe236..c43c8c2a475 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -468,11 +468,14 @@ def new_fun(self, netname=None):
     return new_fun


-def vmap_func(module, *args, **kwargs):
+def _vmap_func(module, *args, func=None, **kwargs):
     def decorated_module(*module_args_params):
         params = module_args_params[-1]
         module_args = module_args_params[:-1]
         with params.to_module(module):
-            return module(*module_args)
+            if func is None:
+                return module(*module_args)
+            else:
+                return getattr(module, func)(*module_args)

-    return vmap(decorated_module, *args, **kwargs)
+    return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index 626a07d162f..2d2d9b20a8d 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -25,7 +25,7 @@
 from torchrl._utils import RL_WARNINGS
 from torchrl.envs.utils import step_mdp

-from torchrl.objectives.utils import hold_out_net
+from torchrl.objectives.utils import hold_out_net, _vmap_func
 from torchrl.objectives.value.functional import (
     generalized_advantage_estimate,
     td0_return_estimate,
@@ -136,14 +136,8 @@ def _call_value_nets(
             "params and next_params must be either both provided or not."
         )
     elif params is not None:
-        params_stack = torch.stack([params, next_params], 0)
-
-        def call_value_net(data_in, params_stack):
-            with params_stack.to_module(value_net):
-                out = value_net(data_in)
-            return out
-
-        data_out = vmap(call_value_net, (0, 0))(data_in, params_stack)
+        params_stack = torch.stack([params, next_params], 0).contiguous()
+        data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack)
     else:
         data_out = vmap(value_net, (0,))(data_in)
     value_est = data_out.get(value_key)
@@ -568,7 +562,9 @@ def forward(
             params = params.detach()
             if target_params is None:
                 target_params = params.clone(False)
-        with hold_out_net(self.value_network):
+        with hold_out_net(self.value_network) if (
+            params is None and target_params is None
+        ) else nullcontext():
             # we may still need to pass gradient, but we don't want to assign grads to
             # value net params
             value, next_value = _call_value_nets(
@@ -769,7 +765,9 @@ def forward(
             params = params.detach()
             if target_params is None:
                 target_params = params.clone(False)
-        with hold_out_net(self.value_network):
+        with hold_out_net(self.value_network) if (
+            params is None and target_params is None
+        ) else nullcontext():
            # we may still need to pass gradient, but we don't want to assign grads to
            # value net params
            value, next_value = _call_value_nets(
@@ -980,7 +978,9 @@ def forward(
             params = params.detach()
             if target_params is None:
                 target_params = params.clone(False)
-        with hold_out_net(self.value_network):
+        with hold_out_net(self.value_network) if (
+            params is None and target_params is None
+        ) else nullcontext():
             # we may still need to pass gradient, but we don't want to assign grads to
             # value net params
             value, next_value = _call_value_nets(
@@ -1308,7 +1308,9 @@ def value_estimate(
             params = params.detach()
             if target_params is None:
                 target_params = params.clone(False)
-        with hold_out_net(self.value_network):
+        with hold_out_net(self.value_network) if (
+            params is None and target_params is None
+        ) else nullcontext():
             # we may still need to pass gradient, but we don't want to assign grads to
             # value net params
             value, next_value = _call_value_nets(
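
Note on the recurring pattern above (illustrative, not part of the patch): every loss module now runs its network inside `params.to_module(network)` instead of passing `params=...` to a functional call. A minimal, self-contained sketch of that pattern; the module, keys and shapes below are hypothetical and only serve the example:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # hypothetical actor, for illustration only
    actor = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["obs"], out_keys=["action"])
    params = TensorDict.from_module(actor)  # parameters gathered as a TensorDict
    data = TensorDict({"obs": torch.randn(2, 3)}, batch_size=[2])

    # old convention (removed in this diff): actor(data, params=params)
    # new convention: temporarily swap the parameters into the module
    with params.to_module(actor):
        out = actor(data)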
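The `_vmap_func` helper introduced in torchrl/objectives/utils.py wraps this same parameter swap inside `torch.vmap`, so one module can be evaluated against a stack of parameter sets. A rough usage sketch mirroring `_vmap_func(self.qvalue_network, (None, 0))` above; the two-entry ensemble and key names are invented for the example:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.utils import _vmap_func

    # hypothetical q-value network, for illustration only
    qvalue = TensorDictModule(
        torch.nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_action_value"]
    )
    # stack two copies of its parameters to emulate an ensemble
    params = torch.stack(
        [TensorDict.from_module(qvalue), TensorDict.from_module(qvalue).clone()], 0
    ).contiguous()
    data = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])

    # (None, 0): broadcast the same tensordict, map over the stacked parameter sets
    vmapped_qvalue = _vmap_func(qvalue, (None, 0))
    out = vmapped_qvalue(data, params)  # leading dim of size 2, one slice per parameter set

The optional `func=` argument (for example `func="get_dist_params"` in redq.py above) selects a method other than `forward` for the vmapped call.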
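The `hold_out_net(...) if ... else nullcontext()` branches in advantages.py rely on `contextlib.nullcontext` as a no-op stand-in; the corresponding import is not visible in the hunks shown here. A tiny, generic illustration of the conditional-context-manager idiom (the `Verbose` manager is invented for the example):

    from contextlib import nullcontext

    class Verbose:
        # stand-in context manager, only to make the branching visible
        def __enter__(self):
            print("holding the net out")
            return self

        def __exit__(self, *exc_info):
            print("releasing the net")

    for no_params_passed in (True, False):
        # pick a real context manager or a no-op at runtime,
        # as the forward()/value_estimate() hunks above do with hold_out_net
        with Verbose() if no_params_passed else nullcontext():
            print("computing values, no_params_passed =", no_params_passed)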