Commit 2cbc053: Update

[ghstack-poisoned]

vmoens committed Feb 3, 2025
2 parents 257e198 + 4f9df7c
Showing 20 changed files with 390 additions and 56 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/benchmarks.yml
@@ -50,6 +50,7 @@ jobs:
           cd benchmarks/
           export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
           export COMPOSITE_LP_AGGREGATE=0
+          export TD_GET_DEFAULTS_TO_NONE=1
           python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
       - name: Store benchmark results
@@ -131,6 +132,7 @@ jobs:
           cd benchmarks/
           export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
           export COMPOSITE_LP_AGGREGATE=0
+          export TD_GET_DEFAULTS_TO_NONE=1
           python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
       - name: Store benchmark results
2 changes: 2 additions & 0 deletions .github/workflows/benchmarks_pr.yml
@@ -48,6 +48,7 @@ jobs:
           cd benchmarks/
           export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
           export COMPOSITE_LP_AGGREGATE=0
+          export TD_GET_DEFAULTS_TO_NONE=1
           RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
           git checkout ${{ github.event.pull_request.base.sha }}
@@ -141,6 +142,7 @@ jobs:
           cd benchmarks/
           export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
           export COMPOSITE_LP_AGGREGATE=0
+          export TD_GET_DEFAULTS_TO_NONE=1
           RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
           git checkout ${{ github.event.pull_request.base.sha }}
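Note: TD_GET_DEFAULTS_TO_NONE=1 opts the benchmark jobs into tensordict's newer TensorDict.get behavior, where a missing key returns None instead of raising a KeyError. A minimal sketch of what the flag toggles (assuming a tensordict version that reads this variable at import time; the demo key name is made up):

    import os

    # Must be set before tensordict is imported.
    os.environ["TD_GET_DEFAULTS_TO_NONE"] = "1"

    import torch
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
    # With the flag on, get() on a missing key returns None rather than raising.
    assert td.get("made_up_key") is None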
25 changes: 21 additions & 4 deletions benchmarks/test_objectives_benchmarks.py
@@ -10,6 +10,7 @@
 
 from tensordict import TensorDict
 from tensordict.nn import (
+    composite_lp_aggregate,
     InteractionType,
     NormalParamExtractor,
     ProbabilisticTensorDictModule as ProbMod,
@@ -785,11 +786,15 @@ def test_a2c_speed(
         device=device,
     )
     batch = [batch, T]
+    if composite_lp_aggregate():
+        raise RuntimeError(
+            "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+        )
     td = TensorDict(
         {
             "obs": torch.randn(*batch, n_obs),
             "action": torch.randn(*batch, n_act),
-            "sample_log_prob": torch.randn(*batch),
+            "action_log_prob": torch.randn(*batch),
             "done": torch.zeros(*batch, 1, dtype=torch.bool),
             "next": {
                 "obs": torch.randn(*batch, n_obs),
@@ -884,11 +889,15 @@ def test_ppo_speed(
         device=device,
     )
     batch = [batch, T]
+    if composite_lp_aggregate():
+        raise RuntimeError(
+            "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+        )
     td = TensorDict(
         {
             "obs": torch.randn(*batch, n_obs),
             "action": torch.randn(*batch, n_act),
-            "sample_log_prob": torch.randn(*batch),
+            "action_log_prob": torch.randn(*batch),
             "done": torch.zeros(*batch, 1, dtype=torch.bool),
             "next": {
                 "obs": torch.randn(*batch, n_obs),
@@ -983,11 +992,15 @@ def test_reinforce_speed(
         device=device,
     )
     batch = [batch, T]
+    if composite_lp_aggregate():
+        raise RuntimeError(
+            "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+        )
     td = TensorDict(
         {
             "obs": torch.randn(*batch, n_obs),
             "action": torch.randn(*batch, n_act),
-            "sample_log_prob": torch.randn(*batch),
+            "action_log_prob": torch.randn(*batch),
             "done": torch.zeros(*batch, 1, dtype=torch.bool),
             "next": {
                 "obs": torch.randn(*batch, n_obs),
@@ -1089,11 +1102,15 @@ def test_iql_speed(
         device=device,
     )
     batch = [batch, T]
+    if composite_lp_aggregate():
+        raise RuntimeError(
+            "Expected composite_lp_aggregate() to return False. Use set_composite_lp_aggregate or COMPOSITE_LP_AGGREGATE env variable."
+        )
     td = TensorDict(
         {
             "obs": torch.randn(*batch, n_obs),
             "action": torch.randn(*batch, n_act),
-            "sample_log_prob": torch.randn(*batch),
+            "action_log_prob": torch.randn(*batch),
             "done": torch.zeros(*batch, 1, dtype=torch.bool),
             "next": {
                 "obs": torch.randn(*batch, n_obs),
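The guard and the sample_log_prob -> action_log_prob rename track tensordict's composite log-prob convention: with aggregation disabled, probabilistic modules write one log-prob entry per leaf key (e.g. "action_log_prob") instead of a single aggregated "sample_log_prob". A hedged sketch of the switch (assumes a tensordict version exposing both helpers):

    from tensordict.nn import composite_lp_aggregate, set_composite_lp_aggregate

    # set_composite_lp_aggregate also works as a context manager/decorator.
    with set_composite_lp_aggregate(False):
        assert composite_lp_aggregate() is False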
24 changes: 19 additions & 5 deletions torchrl/objectives/a2c.py
@@ -262,10 +262,10 @@ def __post_init__(self):
 
     actor_network: TensorDictModule
     critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
+    actor_network_params: TensorDictParams | None
+    critic_network_params: TensorDictParams | None
+    target_actor_network_params: TensorDictParams | None
+    target_critic_network_params: TensorDictParams | None
 
     def __init__(
         self,
@@ -521,6 +521,13 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]:
             loss_value,
             self.loss_critic_type,
         )
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "critic_network_params",
+            "target_actor_network_params",
+            "target_critic_network_params",
+        )
         if self.critic_coef is not None:
             return self.critic_coef * loss_value, clip_fraction
         return loss_value, clip_fraction
@@ -559,7 +566,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
-            batch_size=[],
         )
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "critic_network_params",
+            "target_actor_network_params",
+            "target_critic_network_params",
+        )
         return td_out
 
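A note on the annotation change above: the "| None" union presumably accommodates configurations where the functional parameter copies are never materialized, leaving these attributes unset (an assumption about the surrounding torchrl code, not shown in this diff). A minimal illustration of the annotation pattern only (ExampleLoss is a hypothetical stand-in, not torchrl's A2CLoss):

    from __future__ import annotations

    from tensordict import TensorDictParams
    from tensordict.nn import TensorDictModule

    class ExampleLoss:
        actor_network: TensorDictModule
        # None until/unless functional params are materialized.
        actor_network_params: TensorDictParams | None = None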
19 changes: 15 additions & 4 deletions torchrl/objectives/common.py
@@ -27,16 +27,16 @@
 from torchrl.objectives.value import ValueEstimatorBase
 
 try:
-    from torch.compiler import is_dynamo_compiling
+    from torch.compiler import is_compiling
 except ImportError:
-    from torch._dynamo import is_compiling as is_dynamo_compiling
+    from torch._dynamo import is_compiling
 
 
 def _updater_check_forward_prehook(module, *args, **kwargs):
     if (
         not all(module._has_update_associated.values())
         and RL_WARNINGS
-        and not is_dynamo_compiling()
+        and not is_compiling()
     ):
         warnings.warn(
             module.TARGET_NET_WARNING,
@@ -415,6 +415,7 @@ def _compare_and_expand(param):
             params.set(key, parameter.data)
 
         setattr(self, param_name, params)
+        assert getattr(self, param_name) is params, getattr(self, param_name)
 
         # Set the module in the __dict__ directly to avoid listing its params
         # A deepcopy with meta device could be used but that assumes that the model is copyable!
@@ -433,6 +434,16 @@ def _compare_and_expand(param):
         setattr(self, name_params_target + "_params", target_params)
         self._has_update_associated[module_name] = not create_target_params
 
+    def _clear_weakrefs(self, *tds):
+        if is_compiling():
+            # Waiting for weakrefs reconstruct to be supported by compile
+            for td in tds:
+                if isinstance(td, str):
+                    td = getattr(self, td, None)
+                if not is_tensor_collection(td):
+                    continue
+                td.clear_refs_for_compile_()
+
     def __getattr__(self, item):
         if item.startswith("target_") and item.endswith("_params"):
             params = self._modules.get(item, None)
@@ -443,7 +454,7 @@ def __getattr__(self, item):
         elif (
             not self._has_update_associated[item[7:-7]]
             and RL_WARNINGS
-            and not is_dynamo_compiling()
+            and not is_compiling()
         ):
             # no updater associated
             warnings.warn(
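For orientation: the try/except import keeps compatibility with older torch releases where torch.compiler.is_compiling is unavailable, and _clear_weakrefs is the helper every loss module below calls before returning. Under torch.compile it strips weakrefs that dynamo cannot yet reconstruct; in eager mode it is a no-op. A hedged sketch of the call pattern (MiniLoss is illustrative, not a torchrl API; LossModule and TensorDict are the real names used in this diff):

    from tensordict import TensorDict, TensorDictBase
    from torchrl.objectives.common import LossModule

    class MiniLoss(LossModule):
        def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            loss = tensordict["reward"].mean()  # placeholder computation
            td_out = TensorDict(loss=loss)
            # Accepts tensordicts or attribute names; entries that are not
            # tensor collections (strings resolving to None, plain tensors)
            # are skipped by the is_tensor_collection check.
            self._clear_weakrefs(
                tensordict,
                td_out,
                "actor_network_params",
                "target_actor_network_params",
            )
            return td_out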
54 changes: 52 additions & 2 deletions torchrl/objectives/cql.py
@@ -542,7 +542,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         }
         if self.with_lagrange:
             out["loss_alpha_prime"] = alpha_prime_loss.mean()
-        return TensorDict(out, [])
+        td_loss = TensorDict(out)
+        self._clear_weakrefs(
+            tensordict,
+            td_loss,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
+        return td_loss
 
     @property
     @_cache_values
@@ -563,6 +572,13 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
         bc_actor_loss = self._alpha * log_prob - bc_log_prob
         bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
         metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return bc_actor_loss, metadata
 
     def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
@@ -596,7 +612,13 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         metadata[self.tensor_keys.log_prob] = log_prob.detach()
         actor_loss = self._alpha * log_prob - min_q_logprob
         actor_loss = _reduce(actor_loss, reduction=self.reduction)
-
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return actor_loss, metadata
 
     def _get_policy_actions(self, data, actor_params, num_actions=10):
@@ -712,6 +734,13 @@ def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         loss_qval = _reduce(loss_qval, reduction=self.reduction)
         td_error = (q_pred - target_value).pow(2)
         metadata = {"td_error": td_error.detach()}
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return loss_qval, metadata
 
     def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
@@ -855,6 +884,13 @@ def filter_and_repeat(name, x):
         cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
         cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)
 
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return cql_q_loss, {}
 
     def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
@@ -878,6 +914,13 @@ def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
 
         alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
         alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return alpha_prime_loss, {}
 
     def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
@@ -889,6 +932,13 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
         # placeholder
         alpha_loss = torch.zeros_like(log_pi)
         alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return alpha_loss, {}
 
     @property
8 changes: 8 additions & 0 deletions torchrl/objectives/crossq.py
@@ -542,6 +542,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             **value_metadata,
         }
         td_out = TensorDict(out)
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "qvalue_network_params",
+            "target_actor_network_params",
+            "target_qvalue_network_params",
+        )
         return td_out
 
     @property
23 changes: 23 additions & 0 deletions torchrl/objectives/ddpg.py
@@ -303,6 +303,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
             batch_size=[],
         )
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "value_network_params",
+            "target_value_network_params",
+            "target_actor_network_params",
+            "actor_network_params",
+        )
         return td_out
 
     def loss_actor(
@@ -319,6 +327,14 @@ def loss_actor(
         loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
         metadata = {}
         loss_actor = _reduce(loss_actor, self.reduction)
+        self._clear_weakrefs(
+            tensordict,
+            loss_actor,
+            "value_network_params",
+            "target_value_network_params",
+            "target_actor_network_params",
+            "actor_network_params",
+        )
         return loss_actor, metadata
 
     def loss_value(
@@ -358,6 +374,13 @@ def loss_value(
             "pred_value_max": pred_val.max(),
         }
         loss_value = _reduce(loss_value, self.reduction)
+        self._clear_weakrefs(
+            tensordict,
+            "value_network_params",
+            "target_value_network_params",
+            "target_actor_network_params",
+            "actor_network_params",
+        )
         return loss_value, metadata
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
13 changes: 12 additions & 1 deletion torchrl/objectives/decision_transformer.py
@@ -241,7 +241,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
-            batch_size=[],
         )
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "target_actor_network_params",
+        )
         return td_out
 
@@ -360,4 +365,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         )
         loss = _reduce(loss, reduction=self.reduction)
         td_out = TensorDict(loss=loss)
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "target_actor_network_params",
+        )
         return td_out
(Diff truncated: the remaining 11 of 20 changed files are not rendered on this page.)