[CI] Fix CI (#1711)
vmoens authored Nov 27, 2023
1 parent bc7595f commit 0f93943
Showing 9 changed files with 66 additions and 49 deletions.
12 changes: 9 additions & 3 deletions test/assets/generate.py
@@ -4,6 +4,12 @@
# LICENSE file in the root directory of this source tree.

"""Script used to generate the mini datasets."""
import multiprocessing as mp

try:
mp.set_start_method("spawn")
except Exception:
pass
from tempfile import TemporaryDirectory

from datasets import Dataset, DatasetDict, load_dataset
@@ -36,12 +42,13 @@ def get_minibatch():
batch_size=16,
block_size=33,
tensorclass_type=PromptData,
dataset_name="CarperAI/openai_summarize_tldr",
dataset_name="../datasets_mini/openai_summarize_tldr",
device="cpu",
num_workers=2,
infinite=False,
prefetch=0,
split="train",
from_disk=False,
from_disk=True,
root_dir=tmpdir,
)
for data in dl:
@@ -51,5 +58,4 @@ def get_minibatch():


if __name__ == "__main__":
# generate_small_dataset()
get_minibatch()
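
Note (not part of the diff): the guarded mp.set_start_method("spawn") call added at the top of generate.py can only succeed once per process; a second call raises RuntimeError, which the script's broad except Exception also covers. A minimal sketch of that behaviour:

import multiprocessing as mp

if __name__ == "__main__":
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # the start method was already fixed (e.g. by an earlier import); keep it
        pass
    print(mp.get_start_method())  # "spawn", or whatever was set first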
Binary file modified test/assets/tldr_batch.zip
Binary file not shown.
36 changes: 22 additions & 14 deletions test/test_cost.py
@@ -5985,8 +5985,6 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)

@@ -6018,34 +6016,44 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):

loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2")

floss_fn, params, buffers = make_functional_with_buffers(loss_fn)
params = TensorDict.from_module(loss_fn, as_module=True)

# fill params with zero
for p in params:
p.data.zero_()
def zero_param(p):
if isinstance(p, nn.Parameter):
p.data.zero_()

params.apply(zero_param)

# assert len(list(floss_fn.parameters())) == 0
if advantage is not None:
advantage(td)
loss = floss_fn(params, buffers, td)
with params.to_module(loss_fn):
if advantage is not None:
advantage(td)
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
# check that grads are independent and non null
named_parameters = loss_fn.named_parameters()
for (name, _), p in zip(named_parameters, params):
for name, p in params.items(True, True):
if isinstance(name, tuple):
name = "-".join(name)
if not isinstance(p, nn.Parameter):
continue
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" not in name
assert "critic" in name
if p.grad is None:
assert "actor" in name
assert "critic" not in name

for param in params:
param.grad = None
for p in params.values(True, True):
p.grad = None
loss_objective.backward()
named_parameters = loss_fn.named_parameters()

for (name, other_p), p in zip(named_parameters, params):
for (name, other_p) in named_parameters:
p = params.get(tuple(name.split(".")))
assert other_p.shape == p.shape
assert other_p.dtype == p.dtype
assert other_p.device == p.device
@@ -6055,7 +6063,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
if p.grad is None:
assert "actor" not in name
assert "critic" in name
for param in params:
for param in params.values(True, True):
param.grad = None

@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
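Note (not part of the diff): the test above replaces functorch's make_functional_with_buffers with tensordict's module/parameter API. A minimal sketch of that pattern, assuming a recent tensordict release and a toy nn.Linear:

import torch
from torch import nn
from tensordict import TensorDict

net = nn.Linear(3, 4)
# extract the parameters as a (nested) tensordict that still exposes nn.Parameters
params = TensorDict.from_module(net, as_module=True)
# build a zeroed copy without touching the original weights
zeroed = params.apply(lambda p: p.data.clone().zero_())
with zeroed.to_module(net):
    out = net(torch.randn(2, 3))  # runs with the zeroed weights
assert (out == 0).all()
# outside the context manager the original weights are restored
for name, p in params.items(True, True):  # iterate nested leaves, as in the test
    print(name, p.shape)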
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/dataset.py
@@ -137,7 +137,7 @@ def load(self):
data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
data_dir_total = data_dir / split / str(max_length)
# search for data
print(data_dir_total)
print("Looking for data in", data_dir_total)
if os.path.exists(data_dir_total):
dataset = TensorDict.load_memmap(data_dir_total)
return dataset
17 changes: 14 additions & 3 deletions torchrl/objectives/common.py
@@ -10,6 +10,7 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple

import torch
from tensordict import TensorDict, TensorDictBase

from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
@@ -288,11 +289,11 @@ def _compare_and_expand(param):

# set the functional module: we need to convert the params to non-differentiable params
# otherwise they will appear twice in parameters
p = TensorDict.from_module(module)
with params.detach().to("meta").to_module(module):
with params.apply(_make_meta_params, device=torch.device("meta")).to_module(
module
):
# avoid buffers and params being exposed
self.__dict__[module_name] = deepcopy(module)
assert (p == TensorDict.from_module(module)).all()

name_params_target = "target_" + module_name
if create_target_params:
@@ -445,3 +446,13 @@ def __call__(self, x):
x.data.clone() if self.clone else x.data, requires_grad=False
)
return x.data.clone() if self.clone else x.data


def _make_meta_params(param):
is_param = isinstance(param, nn.Parameter)

pd = param.detach().to("meta")

if is_param:
pd = nn.Parameter(pd, requires_grad=False)
return pd
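
Note (not part of the diff): the helper above is used with params.apply(_make_meta_params, device=torch.device("meta")).to_module(module), so the deepcopy made inside the context captures the module's structure without duplicating or re-registering its real weights. A minimal sketch of the idea, with a hypothetical nn.Linear:

from copy import deepcopy

import torch
from torch import nn
from tensordict import TensorDict


def _make_meta_params(param):
    # detach and move to the zero-storage "meta" device; keep nn.Parameter-ness
    is_param = isinstance(param, nn.Parameter)
    pd = param.detach().to("meta")
    if is_param:
        pd = nn.Parameter(pd, requires_grad=False)
    return pd


net = nn.Linear(3, 4)
params = TensorDict.from_module(net)
with params.apply(_make_meta_params, device=torch.device("meta")).to_module(net):
    skeleton = deepcopy(net)  # structural copy whose tensors live on "meta"
print(skeleton.weight.device)  # meta
print(net.weight.device)       # original device, restored after the context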
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
@@ -370,7 +370,7 @@ def __init__(
self._target_entropy = target_entropy
self._action_spec = action_spec
if self._version == 1:
self.actor_critic = ActorCriticWrapper(
self.__dict__["actor_critic"] = ActorCriticWrapper(
self.actor_network, self.value_network
)
if gamma is not None:
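Note (not part of the diff): writing the wrapper into self.__dict__ instead of assigning it as an attribute keeps it out of nn.Module's submodule registry, so the wrapped networks' parameters are not exposed a second time through the loss module. A minimal sketch of the difference, using a hypothetical holder module:

from torch import nn


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.Linear(2, 2)           # goes through nn.Module.__setattr__
        self.__dict__["hidden"] = nn.Linear(2, 2)   # plain attribute, not registered


h = Holder()
print([name for name, _ in h.named_modules()])  # ['', 'registered'] -- no 'hidden'
print(len(list(h.parameters())))                # 2: only the registered layer's weight and bias
print(h.hidden)                                 # still reachable as a normal attribute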
33 changes: 11 additions & 22 deletions torchrl/objectives/value/advantages.py
@@ -294,7 +294,7 @@ def __init__(
self._tensor_keys = None
self.differentiable = differentiable
self.skip_existing = skip_existing
self.value_network = value_network
self.__dict__["value_network"] = value_network
self.dep_keys = {}
self.shifted = shifted

@@ -471,6 +471,7 @@ class TD0Estimator(ValueEstimatorBase):
of the advantage entry. Defaults to ``"value_target"``.
value_key (str or tuple of str, optional): [Deprecated] the value key to
read from the input tensordict. Defaults to ``"state_value"``.
device (torch.device, optional): device of the module.
"""

@@ -486,6 +487,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
skip_existing: Optional[bool] = None,
device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
Expand All @@ -496,10 +498,6 @@ def __init__(
value_key=value_key,
skip_existing=skip_existing,
)
try:
device = next(value_network.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.average_rewards = average_rewards

@@ -675,6 +673,7 @@ class TD1Estimator(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): device of the module.
"""

@@ -690,6 +689,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
@@ -700,10 +700,6 @@
shifted=shifted,
skip_existing=skip_existing,
)
try:
device = next(value_network.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.average_rewards = average_rewards

@@ -883,6 +879,7 @@ class TDLambdaEstimator(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): device of the module.
"""

@@ -900,6 +897,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
@@ -910,10 +908,6 @@
skip_existing=skip_existing,
shifted=shifted,
)
try:
device = next(value_network.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
self.average_rewards = average_rewards
@@ -1113,6 +1107,7 @@ class GAE(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): device of the module.
GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the return value that is to be used
@@ -1142,6 +1137,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(
shifted=shifted,
@@ -1152,10 +1148,6 @@
value_key=value_key,
skip_existing=skip_existing,
)
try:
device = next(value_network.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
self.average_gae = average_gae
@@ -1403,6 +1395,7 @@ class VTrace(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
device (torch.device, optional): device of the module.
VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the V-Trace target value.
@@ -1429,6 +1422,7 @@ def __init__(
value_target_key: Optional[NestedKey] = None,
value_key: Optional[NestedKey] = None,
shifted: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(
shifted=shifted,
@@ -1439,11 +1433,6 @@
value_key=value_key,
skip_existing=skip_existing,
)
try:
device = next(value_network.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")

if not isinstance(gamma, torch.Tensor):
gamma = torch.tensor(gamma, device=device)
if not isinstance(rho_thresh, torch.Tensor):
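Note (not part of the diff): with this change the value estimators no longer infer a device from the value network's parameters; the gamma/lmbda buffers are created on the new optional device argument instead. A hedged usage sketch with a hypothetical value network:

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

value_net = TensorDictModule(
    nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
)
adv = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=value_net,
    device=torch.device("cpu"),  # new explicit argument; the buffers land here
)
print(adv.gamma.device)  # cpu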
12 changes: 7 additions & 5 deletions tutorials/sphinx-tutorials/rb_tutorial.py
@@ -46,6 +46,7 @@
# replay buffer is a straightforward process, as shown in the following
# example:
#
import tempfile

from torchrl.data import ReplayBuffer

@@ -175,9 +176,8 @@
######################################################################
# We can also customize the storage location on disk:
#
buffer_lazymemmap = ReplayBuffer(
storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/")
)
tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename)
@@ -207,8 +207,9 @@

from torchrl.data import TensorDictReplayBuffer

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12
storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
@@ -248,8 +249,9 @@ class MyData:
batch_size=[1000],
)

tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12
storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
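Note (not part of the diff): the tutorial now builds its memmap scratch directory with tempfile instead of hard-coding /tmp/memmap/. A minimal sketch of the same pattern; it passes the directory's .name and keeps the TemporaryDirectory object alive so the backing files are not cleaned up while the buffer is in use:

import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

size = 100
tempdir = tempfile.TemporaryDirectory()
buffer = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir.name))
data = TensorDict({"a": torch.zeros(10, 3)}, batch_size=[10])
buffer.extend(data)
print(len(buffer))  # 10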
1 change: 1 addition & 0 deletions tutorials/sphinx-tutorials/torchrl_envs.py
@@ -25,6 +25,7 @@
# will pass the arguments and keyword arguments to the root library builder.
#
# With gym, it means that building an environment is as easy as:

# sphinx_gallery_start_ignore
import warnings

