diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index 2143b610b..9521cb361 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -232,9 +232,9 @@ def get_gpu_count(): return 0 def torch_assert_equal(actual, expected, **kwargs): - # assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match - if hasattr(torch.testing, "assert_equal"): - return torch.testing.assert_equal(actual, expected, **kwargs) + # assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match + if hasattr(torch.testing, "assert_close"): + return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs) else: return torch.allclose(actual, expected, rtol=0.0, atol=0.0) @@ -886,4 +886,4 @@ def flatten_arguments(args): Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"] """ - return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""] \ No newline at end of file + return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]