Skip to content

Commit

Permalink
torch.testing.assert_equal didn't make it (#273)
Browse files Browse the repository at this point in the history
looks like pt-1.11 dropped `torch.testing.assert_equal`, so using `torch.testing.assert_equal` instead
  • Loading branch information
stas00 authored Mar 25, 2022
1 parent affff3d commit 87a9dba
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions megatron/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def get_gpu_count():
return 0

def torch_assert_equal(actual, expected, **kwargs):
# assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match
if hasattr(torch.testing, "assert_equal"):
return torch.testing.assert_equal(actual, expected, **kwargs)
# assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
if hasattr(torch.testing, "assert_close"):
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
else:
return torch.allclose(actual, expected, rtol=0.0, atol=0.0)

Expand Down Expand Up @@ -886,4 +886,4 @@ def flatten_arguments(args):
Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
"""
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]

0 comments on commit 87a9dba

Please sign in to comment.