Add check for the whole state dict
SunMarc committed Jul 12, 2023
1 parent c33adec commit a20ab0d
Showing 3 changed files with 18 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/accelerate/utils/bnb.py
@@ -72,6 +72,7 @@ def load_and_quantize_model(
             - a path to a file containing a whole model state dict
             - a path to a `.json` file containing the index to a sharded checkpoint
             - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
+            - a path to a folder containing a unique pytorch_model.bin file.
         device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
             A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
             name, once a given module name is inside, every submodule of it will be sent to the same device.
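
With this change, `load_and_quantize_model` can be pointed at a folder that holds a single whole-state-dict file. A minimal sketch of the new call pattern, assuming an already-instantiated `empty_model` and a `saved_dir` folder containing exactly one pytorch_model.bin (both names are placeholders, not part of the commit):

    from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

    # `empty_model` and `saved_dir` are illustrative placeholders.
    bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True)
    quantized_model = load_and_quantize_model(
        empty_model,
        bnb_quantization_config,
        weights_location=saved_dir,  # a folder with one pytorch_model.bin now resolves
        device_map="auto",
    )
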
23 changes: 16 additions & 7 deletions src/accelerate/utils/modeling.py
@@ -1183,6 +1183,7 @@ def load_checkpoint_in_model(
             - a path to a file containing a whole model state dict
             - a path to a `.json` file containing the index to a sharded checkpoint
             - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
+            - a path to a folder containing a unique pytorch_model.bin file.
         device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
             A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
             name, once a given module name is inside, every submodule of it will be sent to the same device.
@@ -1233,17 +1234,25 @@ def load_checkpoint_in_model(
         else:
             checkpoint_files = [checkpoint]
     elif os.path.isdir(checkpoint):
-        potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
-        if len(potential_index) == 0:
-            raise ValueError(f"{checkpoint} is not a folder containing a `.index.json` file.")
-        elif len(potential_index) == 1:
-            index_filename = os.path.join(checkpoint, potential_index[0])
+        # check if the whole state dict is present
+        potential_state = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
+        if len(potential_state) == 1:
+            checkpoint_files = [os.path.join(checkpoint, potential_state[0])]
         else:
-            raise ValueError(f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones.")
+            # otherwise check for sharded checkpoints
+            potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
+            if len(potential_index) == 0:
+                raise ValueError(f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} file")
+            elif len(potential_index) == 1:
+                index_filename = os.path.join(checkpoint, potential_index[0])
+            else:
+                raise ValueError(
+                    f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
+                )
     else:
         raise ValueError(
             "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
-            f"checkpoint, or a folder containing a sharded checkpoint, but got {checkpoint}."
+            f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
         )
 
     if index_filename is not None:
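
The directory branch now prefers a whole state dict and only falls back to a sharded index when none is found. A standalone sketch of that resolution order (WEIGHTS_NAME is "pytorch_model.bin" in accelerate's constants; the helper name below is made up for illustration):

    import os

    WEIGHTS_NAME = "pytorch_model.bin"  # mirrors accelerate's constant

    def resolve_checkpoint(checkpoint_dir):
        # Hypothetical helper returning (checkpoint_files, index_filename)
        # with the same precedence as load_checkpoint_in_model.
        files = os.listdir(checkpoint_dir)
        if WEIGHTS_NAME in files:
            # whole state dict present: load it directly
            return [os.path.join(checkpoint_dir, WEIGHTS_NAME)], None
        indexes = [f for f in files if f.endswith(".index.json")]
        if len(indexes) == 1:
            # sharded checkpoint: defer to the unique index file
            return None, os.path.join(checkpoint_dir, indexes[0])
        raise ValueError(
            f"{checkpoint_dir} must contain a {WEIGHTS_NAME} file or exactly one `.index.json` file"
        )
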
2 changes: 1 addition & 1 deletion tests/test_quantization.py
@@ -382,7 +382,7 @@ def test_int8_serialization(self):
         model_8bit_from_saved = load_and_quantize_model(
             model_8bit_from_saved,
             bnb_quantization_config,
-            weights_location=tmpdirname + "/pytorch_model.bin",
+            weights_location=tmpdirname,
             device_map="auto",
             no_split_module_classes=["BloomBlock"],
         )
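
The test previously had to point at the pytorch_model.bin file itself; after this commit the containing folder resolves to the same weights. A hedged sketch of that equivalence, with a toy module standing in for the Bloom model the real test uses:

    import os
    import tempfile

    import torch
    from accelerate.utils import load_checkpoint_in_model

    model = torch.nn.Linear(4, 4)  # toy stand-in for the test's model

    with tempfile.TemporaryDirectory() as tmpdirname:
        torch.save(model.state_dict(), os.path.join(tmpdirname, "pytorch_model.bin"))
        # explicit file path: worked before this commit
        load_checkpoint_in_model(model, os.path.join(tmpdirname, "pytorch_model.bin"))
        # folder path: accepted as of this commit
        load_checkpoint_in_model(model, tmpdirname)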
