Commit

add argument to pass shared tensors keys to discard (#2696)
hanouticelina committed Dec 6, 2024
1 parent d7bead5 commit f9beb79
Showing 2 changed files with 99 additions and 11 deletions.
57 changes: 46 additions & 11 deletions src/huggingface_hub/serialization/_torch.py
@@ -41,6 +41,8 @@ def save_torch_model(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
):
"""
Saves a given torch model to disk, handling sharding and shared tensor issues.
@@ -64,6 +66,12 @@ def save_torch_model(
</Tip>
<Tip warning={true}>
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle the saving of shared tensors. This ensures that the correct duplicate tensors are discarded during saving.
</Tip>
Args:
model (`torch.nn.Module`):
The model to save on disk.
@@ -88,6 +96,13 @@ def save_torch_model(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful for distributed training (e.g., on
TPUs), where the function needs to be called from all processes. In that case, set `is_main_process=True`
only on the main process to avoid race conditions. Defaults to True.
shared_tensors_to_discard (`List[str]`, *optional*):
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
detected, all duplicate names except the first one (in alphabetical order) are dropped.
Example:
@@ -112,6 +127,8 @@ def save_torch_model(
metadata=metadata,
safe_serialization=safe_serialization,
save_directory=save_directory,
is_main_process=is_main_process,
shared_tensors_to_discard=shared_tensors_to_discard,
)


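Note (not part of the diff): a minimal usage sketch of the two new arguments, assuming a `transformers` model with tied weights; the checkpoint name is illustrative only.

from huggingface_hub import save_torch_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any model with tied weights

# Per the Tip in the docstring, pass the model's tied-weight keys so the
# intended duplicates are dropped; in distributed training, only the main
# process should pass is_main_process=True.
save_torch_model(
    model,
    save_directory="./my-model",
    is_main_process=True,
    shared_tensors_to_discard=model._tied_weights_keys,
)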
@@ -124,6 +141,8 @@ def save_torch_state_dict(
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
) -> None:
"""
Save a model state dictionary to disk, handling sharding and shared tensor issues.
@@ -147,6 +166,12 @@ def save_torch_state_dict(
</Tip>
<Tip warning={true}>
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle the saving of shared tensors. This ensures that the correct duplicate tensors are discarded during saving.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`):
The state dictionary to save.
@@ -171,6 +196,13 @@ def save_torch_state_dict(
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful for distributed training (e.g., on
TPUs), where the function needs to be called from all processes. In that case, set `is_main_process=True`
only on the main process to avoid race conditions. Defaults to True.
shared_tensors_to_discard (`List[str]`, *optional*):
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
detected, all duplicate names except the first one (in alphabetical order) are dropped.
Example:
@@ -192,7 +224,8 @@ def save_torch_state_dict(
else constants.PYTORCH_WEIGHTS_FILE_PATTERN
)

# Imports correct library
if metadata is None:
metadata = {}
if safe_serialization:
try:
from safetensors.torch import save_file as save_file_fn
@@ -201,7 +234,13 @@ def save_torch_state_dict(
"Please install `safetensors` to use safe serialization. "
"You can install it with `pip install safetensors`."
) from e

# Clean state dict for safetensors
state_dict = _clean_state_dict_for_safetensors(
state_dict,
metadata,
force_contiguous=force_contiguous,
shared_tensors_to_discard=shared_tensors_to_discard,
)
else:
from torch import save as save_file_fn # type: ignore[assignment]

@@ -210,13 +249,6 @@ def save_torch_state_dict(
"pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
"using safe serialization by installing `safetensors` with `pip install safetensors`."
)

# Clean state dict for safetensors
if metadata is None:
metadata = {}
if safe_serialization:
state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)

# Split dict
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
@@ -459,15 +491,18 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:


def _clean_state_dict_for_safetensors(
state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
state_dict: Dict[str, "torch.Tensor"],
metadata: Dict[str, str],
force_contiguous: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
):
"""Remove shared tensors from state_dict and update metadata accordingly (for reloading).
Warning: `state_dict` and `metadata` are mutated in-place!
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
"""
to_removes = _remove_duplicate_names(state_dict)
to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
for kept_name, to_remove_group in to_removes.items():
for to_remove in to_remove_group:
if metadata is None:
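Note (not part of the diff): a small behavioral sketch of the new `discard_names` plumbing. It imports a private helper, shown purely for illustration; the expected outputs match the tests below.

import torch
from huggingface_hub.serialization._torch import _remove_duplicate_names  # private helper

shared = torch.zeros(4)
state_dict = {"shared_1": shared, "shared_2": shared}

# Default: the first name alphabetically is kept, all other aliases are dropped.
print(dict(_remove_duplicate_names(state_dict)))
# {'shared_1': ['shared_2']}

# With discard_names, the choice flips: "shared_2" is kept instead.
print(dict(_remove_duplicate_names(state_dict, discard_names=["shared_1"])))
# {'shared_2': ['shared_1']}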
53 changes: 53 additions & 0 deletions tests/test_serialization.py
@@ -264,6 +264,8 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
shared_tensors_to_discard=None,
)
safe_state_dict_mock.assert_called_once_with(
state_dict=model_mock.state_dict.return_value,
@@ -273,6 +275,8 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
max_shard_size="3GB",
metadata={"foo": "bar"},
safe_serialization=True,
is_main_process=True,
shared_tensors_to_discard=None,
)


@@ -414,6 +418,55 @@ def test_save_torch_state_dict_shared_layers_sharded(
assert "shared_2" not in state_dict


def test_save_torch_state_dict_discard_selected_sharded(
tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
) -> None:
from safetensors.torch import load_file

save_torch_state_dict(
torch_state_dict_shared_layers,
tmp_path,
max_shard_size=2,
safe_serialization=True,
shared_tensors_to_discard=["shared_1"],
)
index_file = tmp_path / "model.safetensors.index.json"
index = json.loads(index_file.read_text())

assert index["metadata"]["shared_1"] == "shared_2"

for filename in index["weight_map"].values():
state_dict = load_file(tmp_path / filename)
assert "shared_1" not in state_dict


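Note (not part of the diff): the `index["metadata"]` assertion above shows what the discard mapping is for: a loader can re-create the dropped aliases after reading the shards. A possible consumer-side sketch; the helper name is hypothetical.

import json
from pathlib import Path
from safetensors.torch import load_file

def reload_with_aliases(save_dir: Path) -> dict:
    """Reload sharded weights and re-tie the shared tensors discarded at save time."""
    index = json.loads((save_dir / "model.safetensors.index.json").read_text())
    reloaded = {}
    for filename in set(index["weight_map"].values()):
        reloaded.update(load_file(save_dir / filename))
    # The index metadata maps each discarded name to the kept one
    # (e.g. "shared_1" -> "shared_2"); unrelated entries are skipped
    # by the membership checks.
    for discarded, kept in index["metadata"].items():
        if discarded not in reloaded and kept in reloaded:
            reloaded[discarded] = reloaded[kept]
    return reloaded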
def test_save_torch_state_dict_discard_selected_not_sharded(
tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
) -> None:
from safetensors.torch import load_file

save_torch_state_dict(
torch_state_dict_shared_layers,
tmp_path,
safe_serialization=True,
shared_tensors_to_discard=["shared_1"],
)
safetensors_file = tmp_path / "model.safetensors"
assert safetensors_file.is_file()

# Check shared layer not duplicated in file
state_dict = load_file(safetensors_file)
assert "shared_1" not in state_dict
assert "shared_2" in state_dict

# Check shared layer info in metadata
file_bytes = safetensors_file.read_bytes()
metadata_str = file_bytes[
8 : struct.unpack("<Q", file_bytes[:8])[0] + 8
].decode() # TODO: next time add helper for this
assert json.loads(metadata_str)["__metadata__"]["shared_1"] == "shared_2"

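Note (not part of the diff): regarding the TODO above, the header parsing could live in a small helper, since a safetensors file starts with an 8-byte little-endian length N followed by N bytes of JSON header. A possible sketch; the helper is hypothetical.

import json
import struct
from pathlib import Path

def read_safetensors_metadata(path: Path) -> dict:
    """Return the `__metadata__` dict from a .safetensors file header."""
    raw = path.read_bytes()
    header_size = struct.unpack("<Q", raw[:8])[0]  # first 8 bytes: JSON header length
    header = json.loads(raw[8 : 8 + header_size])
    return header.get("__metadata__", {})

# Usage in the test above would become:
# assert read_safetensors_metadata(safetensors_file)["shared_1"] == "shared_2"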

def test_split_torch_state_dict_into_shards(
tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
):
