Skip to content

Commit

Permalink
Use weights_only for load (#1933)
Browse files Browse the repository at this point in the history
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
  • Loading branch information
kit1980 and qgallouedec authored Aug 26, 2024
1 parent 2fbc0f4 commit de024ec
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/ddpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, *, dtype, model_id, model_filename):
cached_path = hf_hub_download(model_id, model_filename)
except EntryNotFoundError:
cached_path = os.path.join(model_id, model_filename)
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_peft_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_save_pretrained_peft(self):
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
assert os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin")
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin", weights_only=True)
# check that only keys that starts with `v_head` are in the dict
assert all(
k.startswith("v_head") for k in maybe_v_head.keys()
Expand Down
2 changes: 1 addition & 1 deletion trl/models/auxiliary_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, *, dtype, model_id, model_filename):
cached_path = hf_hub_download(model_id, model_filename)
except EntryNotFoundError:
cached_path = os.path.join(model_id, model_filename)
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()
Expand Down

0 comments on commit de024ec

Please sign in to comment.