Commit 44f4f55: amend

vmoens committed Nov 24, 2023 · 1 parent 957c93a
Showing 4 changed files with 21 additions and 5 deletions.
10 changes: 8 additions & 2 deletions test/assets/generate.py
@@ -36,14 +36,20 @@ def get_minibatch():
         batch_size=16,
         block_size=33,
         tensorclass_type=PromptData,
-        dataset_name="test/datasets_mini/openai_summarize_tldr",
+        dataset_name="CarperAI/openai_summarize_tldr",
         device="cpu",
         infinite=False,
         prefetch=0,
         split="train",
-        from_disk=True,
+        from_disk=False,
         root_dir=tmpdir,
     )
+    for data in dl:
+        data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
+        break
+    print("done")
+
+
 if __name__ == "__main__":
     # generate_small_dataset()
     get_minibatch()
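
The added loop memmaps a single 16-sample batch to test/datasets_mini/tldr_batch/. As a minimal sketch of reading it back (assuming the tensordict package, whose TensorDict.load_memmap the code above already relies on):

    from tensordict import TensorDict

    # Reload the memmapped minibatch written by get_minibatch() above.
    batch = TensorDict.load_memmap("test/datasets_mini/tldr_batch/")
    print(batch.batch_size)  # expected torch.Size([16]) for batch_size=16
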
1 change: 1 addition & 0 deletions torchrl/data/rlhf/dataset.py
@@ -137,6 +137,7 @@ def load(self):
         data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
         data_dir_total = data_dir / split / str(max_length)
         # search for data
+        print(data_dir_total)
         if os.path.exists(data_dir_total):
             dataset = TensorDict.load_memmap(data_dir_total)
             return dataset
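
The two path lines above decide where a preprocessed dataset is cached on disk. A small worked sketch of the construction (the root_dir and max_length values here are hypothetical, not the loader's defaults):

    from pathlib import Path

    root_dir = Path("/tmp/torchrl_data")  # hypothetical cache root
    dataset_name = "CarperAI/openai_summarize_tldr"
    split, max_length = "train", 550      # hypothetical max_length

    # Path(...).name drops the org prefix; split("-")[0] trims any dashed suffix.
    data_dir = root_dir / str(Path(dataset_name).name).split("-")[0]
    data_dir_total = data_dir / split / str(max_length)
    print(data_dir_total)  # /tmp/torchrl_data/openai_summarize_tldr/train/550
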
6 changes: 3 additions & 3 deletions torchrl/envs/transforms/rlhf.py
@@ -11,7 +11,7 @@
 from torch import nn
 from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
 from torchrl.envs.transforms.transforms import Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance
+from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param


 class KLRewardTransform(Transform):
@@ -112,9 +112,9 @@ def __init__(

         # check that the model has parameters
         params = TensorDict.from_module(actor)
-        with params.apply(lambda t: t.data.to("meta")).to_module(actor):
+        with params.apply(_stateless_param).to_module(actor):
             # copy a stateless actor
-            self.functional_actor = deepcopy(actor)
+            self.__dict__["functional_actor"] = deepcopy(actor)
         # we need to register these params as buffer to have `to` and similar
         # methods work properly

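
Writing to self.__dict__ matters because nn.Module.__setattr__ registers any module-valued attribute as a submodule, so the stateless copy's meta-device parameters would otherwise show up in parameters() and state_dict(). A standalone illustration (class and attribute names here are made up for the example):

    from torch import nn

    class Holder(nn.Module):
        def __init__(self):
            super().__init__()
            self.registered = nn.Linear(2, 2)          # goes through __setattr__
            self.__dict__["hidden"] = nn.Linear(2, 2)  # plain attribute, skips registration

    holder = Holder()
    print([name for name, _ in holder.named_modules()])  # ['', 'registered']
    print(sum(p.numel() for p in holder.parameters()))   # 6 (only `registered` counts)
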
9 changes: 9 additions & 0 deletions torchrl/envs/transforms/utils.py
@@ -5,6 +5,7 @@


 import torch
+from torch import nn


 def check_finite(tensor: torch.Tensor):
@@ -59,3 +60,11 @@ def _get_reset(reset_key, tensordict):
     if _reset.ndim > parent_td.ndim:
         _reset = _reset.flatten(parent_td.ndim, -1).any(-1)
     return _reset
+
+
+def _stateless_param(param):
+    is_param = isinstance(param, nn.Parameter)
+    param = param.data.to("meta")
+    if is_param:
+        return nn.Parameter(param, requires_grad=False)
+    return param
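
_stateless_param moves a tensor to the meta device and, if it was an nn.Parameter, re-wraps it as one with requires_grad=False so the placeholder is never trained. A usage sketch mirroring the KLRewardTransform change above (assumes the tensordict package and the _stateless_param helper just defined):

    from tensordict import TensorDict
    from torch import nn

    actor = nn.Linear(4, 4)
    params = TensorDict.from_module(actor)
    with params.apply(_stateless_param).to_module(actor):
        # Inside the context the module is stateless: weights live on "meta",
        # so it can be deepcopied without duplicating real storage.
        print(actor.weight.device)  # meta
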
