Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace assert statements with exceptions #24856

Merged
merged 22 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c676be1
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
685f609
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
6edd21a
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
5b957a5
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
f87cd74
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
740c537
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
6fb512d
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
5ca9cbb
Changed AssertionError to ValueError
syedsalman137 Jul 17, 2023
6eda2e0
Changed assert statement to ValueError based
syedsalman137 Jul 17, 2023
5d91984
Changed assert statement to ValueError based
syedsalman137 Jul 17, 2023
62dc812
Changed assert statement to ValueError based
syedsalman137 Jul 17, 2023
5c2b0eb
Merge branch 'huggingface:main' into main
syedsalman137 Jul 17, 2023
4b437b8
Changed incorrect error handling from AssertionError to ValueError
syedsalman137 Jul 17, 2023
060b6f9
Merge branch 'main' of https://github.com/SalmanHabeeb/transformers
syedsalman137 Jul 17, 2023
0faf7ab
Undoed change from AssertionError to ValueError as it is not needed
syedsalman137 Jul 17, 2023
3a3bf82
Reverted back to using AssertionError as it is not necessary to make …
syedsalman137 Jul 17, 2023
53d5933
Merge branch 'huggingface:main' into main
syedsalman137 Jul 17, 2023
66f8e64
Fixed erraneous comparision
syedsalman137 Jul 17, 2023
614acfb
Merge branch 'main' of https://github.com/SalmanHabeeb/transformers
syedsalman137 Jul 17, 2023
e810357
Fixed erraneous comparision
syedsalman137 Jul 17, 2023
e9fa1cf
formatted the code
syedsalman137 Jul 17, 2023
d3977a3
Ran make fix-copies
syedsalman137 Jul 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
print(f"Initialize PyTorch weight {name} from {original_name}")
Expand Down
17 changes: 9 additions & 8 deletions src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,18 +337,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
else:
raise ValueError("Can't verify logits as model is not supported")

assert logits.shape == expected_shape, "Shape of logits not as expected"
if logits.shape != expected_shape:
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
if not has_lm_head:
if is_semantic:
assert torch.allclose(
logits[0, :3, :3, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
else:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(
logits[0, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"

if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
if logits.argmax(-1).item() != expected_class_idx:
raise ValueError("Predicted class index not as expected")

Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def load_tf_weights_trivia_qa(init_vars):
raise ValueError(
f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}."
)
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
pt_weight_name = ".".join(pt_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
print(f"Initialize PyTorch weight {name}", original_name)
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down