Skip to content

Commit

Permalink
Replace assert statements with exceptions (#24856)
Browse files Browse the repository at this point in the history
* Changed AssertionError to ValueError

try-except block was using AssesrtionError in except statement while the expected error is value error. Fixed the same.

* Changed AssertionError to ValueError

try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same.
Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet)

* Changed AssertionError to ValueError

try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same.
Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet)

* Changed AssertionError to ValueError

* Changed AssertionError to ValueError

* Changed AssertionError to ValueError

* Changed AssertionError to ValueError

* Changed AssertionError to ValueError

* Changed assert statement to ValueError based

* Changed assert statement to ValueError based

* Changed assert statement to ValueError based

* Changed incorrect error handling from AssertionError to ValueError

* Undoed change from AssertionError to ValueError as it is not needed

* Reverted back to using AssertionError as it is not necessary to make it into ValueError

* Fixed erraneous comparision

Changed == to !=

* Fixed erraneous comparision

Changed == to !=

* formatted the code

* Ran make fix-copies
  • Loading branch information
syedsalman137 authored Jul 17, 2023
1 parent 12b908c commit d015401
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
print(f"Initialize PyTorch weight {name} from {original_name}")
Expand Down
17 changes: 9 additions & 8 deletions src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,18 +337,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
else:
raise ValueError("Can't verify logits as model is not supported")

assert logits.shape == expected_shape, "Shape of logits not as expected"
if logits.shape != expected_shape:
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
if not has_lm_head:
if is_semantic:
assert torch.allclose(
logits[0, :3, :3, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
else:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(
logits[0, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"

if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
raise ValueError("First elements of logits not as expected")
if logits.argmax(-1).item() != expected_class_idx:
raise ValueError("Predicted class index not as expected")

Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def load_tf_weights_trivia_qa(init_vars):
raise ValueError(
f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}."
)
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
pt_weight_name = ".".join(pt_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
print(f"Initialize PyTorch weight {name}", original_name)
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/roc_bert/modeling_roc_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path):
try:
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
except ValueError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {name}")
Expand Down

0 comments on commit d015401

Please sign in to comment.