
Tf32 warnings #6816

Merged Aug 7, 2023 (27 commits)

Changes from 17 commits

Commits
5f23835
rename precision doc
qingpeng9802 Aug 3, 2023
6014216
add `version_geq`
qingpeng9802 Aug 3, 2023
e271ae2
detect default tf32 settings
qingpeng9802 Aug 3, 2023
5d286aa
refactor `is_tf32_env()`
qingpeng9802 Aug 3, 2023
3b2345a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2023
35438b8
fix style E402
qingpeng9802 Aug 3, 2023
51f3270
Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI…
qingpeng9802 Aug 3, 2023
5697a31
fix style E722
qingpeng9802 Aug 3, 2023
948ccc3
[MONAI] code formatting
monai-bot Aug 3, 2023
96ca146
refactor the usage of `detect_default_tf32()`
qingpeng9802 Aug 4, 2023
d8a65bb
improve `is_tf32_env()`
qingpeng9802 Aug 4, 2023
93ff777
[MONAI] code formatting
monai-bot Aug 4, 2023
0ad3a73
Merge branch 'dev' into tf32-warnings
wyli Aug 4, 2023
348f089
resolve `torch.cuda` initialization order issue
qingpeng9802 Aug 5, 2023
04b71c3
Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI…
qingpeng9802 Aug 5, 2023
9a4310f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2023
a783158
[MONAI] code formatting
monai-bot Aug 5, 2023
13cd277
use `pynvml` to avoid `torch.cuda` call
qingpeng9802 Aug 7, 2023
fddc87f
minor fix
qingpeng9802 Aug 7, 2023
2a6ff88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2023
e1b2075
[MONAI] code formatting
monai-bot Aug 7, 2023
73cb104
Merge branch 'dev' into tf32-warnings
wyli Aug 7, 2023
72c8a72
fix import `pynvml`
qingpeng9802 Aug 7, 2023
d1ed95b
Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI…
qingpeng9802 Aug 7, 2023
055ed46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2023
f4571f1
[MONAI] code formatting
monai-bot Aug 7, 2023
147224c
Merge branch 'dev' into tf32-warnings
wyli Aug 7, 2023
4 changes: 2 additions & 2 deletions docs/source/index.rst
@@ -60,9 +60,9 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_

.. toctree::
:maxdepth: 1
:caption: Precision and Performance
:caption: Precision and Accelerating

precision_performance
precision_accelerating

.. toctree::
:maxdepth: 1
docs/source/precision_accelerating.md
@@ -29,11 +29,11 @@ by TF32 mode so the impact is very wide.
torch.backends.cuda.matmul.allow_tf32 = False # in PyTorch 1.12 and later.
torch.backends.cudnn.allow_tf32 = True
```
Please note that there are environment variables that can override the flags above. For example, the environment variables mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.

We recommend that users print out these two flags for confirmation when unsure.
Please note that there are environment variables that can override the flags above. For example, the environment variable `NVIDIA_TF32_OVERRIDE` mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.

If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.
The default value of `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.

We recommend that users print out these two flags for confirmation when unsure.

If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.
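
For readers who want to double-check their own environment, a minimal confirmation script might look like this (a sketch; it only prints the two flags and the two override variables discussed above):

```python
import os

import torch

# The two PyTorch flags discussed above.
print("torch.backends.cuda.matmul.allow_tf32 =", torch.backends.cuda.matmul.allow_tf32)
print("torch.backends.cudnn.allow_tf32 =", torch.backends.cudnn.allow_tf32)

# Environment variables that can silently override the defaults.
for var in ("NVIDIA_TF32_OVERRIDE", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"):
    print(var, "=", os.environ.get(var, "<unset>"))
```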
10 changes: 10 additions & 0 deletions monai/__init__.py
@@ -78,3 +78,13 @@
"utils",
"visualize",
]

try:
    from .utils.tf32 import detect_default_tf32

    detect_default_tf32()
except BaseException:
    from .utils.misc import MONAIEnvVars

    if MONAIEnvVars.debug():
        raise

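
As a rough sketch of what this import-time hook means for end users (assuming `monai` has not been imported earlier in the process; capturing warnings this way is just one way to surface the message):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import monai  # detect_default_tf32() runs once here, at import time

for w in caught:
    # Any TF32-related notice raised during import shows up here.
    print(f"{w.category.__name__}: {w.message}")
```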
4 changes: 3 additions & 1 deletion monai/utils/__init__.py
@@ -13,7 +13,7 @@

# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .decorators import MethodReplacer, RestartGenerator
from .decorators import MethodReplacer, RestartGenerator, reset_torch_cuda_after_run
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
@@ -115,6 +115,7 @@
require_pkg,
run_debug,
run_eval,
version_geq,
version_leq,
)
from .nvtx import Range
@@ -128,6 +129,7 @@
torch_profiler_time_end_to_end,
)
from .state_cacher import StateCacher
from .tf32 import detect_default_tf32, has_ampere_or_later
from .type_conversion import (
convert_data_type,
convert_to_cupy,
28 changes: 27 additions & 1 deletion monai/utils/decorators.py
@@ -13,7 +13,7 @@

from functools import wraps

__all__ = ["RestartGenerator", "MethodReplacer"]
__all__ = ["RestartGenerator", "MethodReplacer", "reset_torch_cuda_after_run"]

from typing import Callable, Generator

@@ -80,3 +80,29 @@ def newinit(_self, *args, **kwargs):
namelist.append(entry)

setattr(owner, name, self.meth)


def reset_torch_cuda_after_run(func: Callable) -> Callable:
    """
    To resolve the `torch.cuda` initialization order issue:
    if a function calls `torch.cuda` in an `__init__.py` file,
    the function should be decorated with this decorator
    to maintain CUDA lazy initialization.

    See https://github.com/Project-MONAI/MONAI/issues/2161 and
    https://github.com/pytorch/pytorch/issues/80876
    """

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        val = func(*args, **kwargs)

        import importlib

        import torch

        importlib.reload(torch.cuda)

        return val

    return wrapped_func
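
As a usage sketch (the probe function below is hypothetical, not part of the PR), the decorator is meant to wrap any helper that touches `torch.cuda` during import, so that `torch.cuda` is reloaded after the call and lazy initialization is preserved for downstream code:

```python
from __future__ import annotations

from monai.utils.decorators import reset_torch_cuda_after_run


@reset_torch_cuda_after_run
def describe_gpus() -> list[str]:
    # Hypothetical import-time probe: calls such as get_device_name would
    # normally initialize the CUDA context as a side effect.
    import torch

    return [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]


gpus = describe_gpus()  # torch.cuda is reloaded once the call returns
```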
68 changes: 52 additions & 16 deletions monai/utils/module.py
@@ -25,7 +25,7 @@
from pydoc import locate
from re import match
from types import FunctionType, ModuleType
from typing import Any, cast
from typing import Any, Iterable, cast

import torch

@@ -55,6 +55,7 @@
"get_package_version",
"get_torch_version_tuple",
"version_leq",
"version_geq",
"pytorch_after",
]

@@ -518,24 +519,11 @@
return tuple(int(x) for x in torch.__version__.split(".")[:2])


def version_leq(lhs: str, rhs: str) -> bool:
def parse_version_strs(lhs: str, rhs: str) -> tuple[Iterable[int | str], Iterable[int | str]]:
"""
Returns True if version `lhs` is earlier or equal to `rhs`.

Args:
lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.

Parse the version strings.
"""

lhs, rhs = str(lhs), str(rhs)
pkging, has_ver = optional_import("pkg_resources", name="packaging")
if has_ver:
try:
return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
except pkging.version.InvalidVersion:
return True

def _try_cast(val: str) -> int | str:
val = val.strip()
try:
@@ -554,7 +542,28 @@
# parse the version strings in this basic way without `packaging` package
lhs_ = map(_try_cast, lhs.split("."))
rhs_ = map(_try_cast, rhs.split("."))
return lhs_, rhs_


def version_leq(lhs: str, rhs: str) -> bool:
    """
    Returns True if version `lhs` is earlier or equal to `rhs`.

    Args:
        lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
        rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.

    """

    lhs, rhs = str(lhs), str(rhs)
    pkging, has_ver = optional_import("pkg_resources", name="packaging")
    if has_ver:
        try:
            return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
        except pkging.version.InvalidVersion:
            return True

    lhs_, rhs_ = parse_version_strs(lhs, rhs)

    for l, r in zip(lhs_, rhs_):
        if l != r:
            if isinstance(l, int) and isinstance(r, int):
@@ -564,6 +573,33 @@
    return True


def version_geq(lhs: str, rhs: str) -> bool:
    """
    Returns True if version `lhs` is later or equal to `rhs`.

    Args:
        lhs: version name to compare with `rhs`, return True if later or equal to `rhs`.
        rhs: version name to compare with `lhs`, return True if earlier or equal to `lhs`.

    """
    lhs, rhs = str(lhs), str(rhs)
    pkging, has_ver = optional_import("pkg_resources", name="packaging")
    if has_ver:
        try:
            return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
        except pkging.version.InvalidVersion:
            return True

    lhs_, rhs_ = parse_version_strs(lhs, rhs)
    for l, r in zip(lhs_, rhs_):
        if l != r:
            if isinstance(l, int) and isinstance(r, int):
                return l > r
            return f"{l}" > f"{r}"

    return True


@functools.lru_cache(None)
def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool:
"""
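
A few illustrative calls for the new and existing comparison helpers (a sketch; the expected results follow from the comparison rules implemented above):

```python
from monai.utils.module import version_geq, version_leq

assert version_leq("1.7.0", "1.12.0")  # 1.7.0 is earlier than 1.12.0
assert version_geq("11.8", "11.0")     # e.g. a CUDA version check: 11.8 is at least 11.0
assert version_leq("1.12", "1.12") and version_geq("1.12", "1.12")  # equal versions satisfy both
```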
79 changes: 79 additions & 0 deletions monai/utils/tf32.py
@@ -0,0 +1,79 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import os
import warnings

from monai.utils.decorators import reset_torch_cuda_after_run

__all__ = ["has_ampere_or_later", "detect_default_tf32"]


@functools.lru_cache(None)
@reset_torch_cuda_after_run
def has_ampere_or_later() -> bool:
    """
    Check if there is any Ampere or later GPU.
    """
    import torch

    from monai.utils.module import version_geq

    if not version_geq(f"{torch.version.cuda}", "11.0"):
        return False

    for i in range(torch.cuda.device_count()):
        major, _ = torch.cuda.get_device_capability(i)
        if major >= 8:  # Ampere and later
            return True

    return False


@functools.lru_cache(None)
def detect_default_tf32() -> bool:
    """
    Detect if there is anything that may enable TF32 mode by default.
    If any, show a warning message.
    """
    may_enable_tf32 = False
    try:
        if not has_ampere_or_later():
            return False

        from monai.utils.module import pytorch_after

        if pytorch_after(1, 7, 0) and not pytorch_after(1, 12, 0):
            warnings.warn(
                "torch.backends.cuda.matmul.allow_tf32 = True by default.\n"
                " This value defaults to True when PyTorch version in [1.7, 1.11] and may affect precision.\n"
                " See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
            )
            may_enable_tf32 = True

        override_tf32_env_vars = {"NVIDIA_TF32_OVERRIDE": "1", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE": "1"}
        for name, override_val in override_tf32_env_vars.items():
            if os.environ.get(name) == override_val:
                warnings.warn(
                    f"Environment variable `{name} = {override_val}` is set.\n"
                    f" This environment variable may enable TF32 mode accidentally and affect precision.\n"
                    f" See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
                )
                may_enable_tf32 = True

        return may_enable_tf32
    except BaseException:
        from monai.utils.misc import MONAIEnvVars

        if MONAIEnvVars.debug():
            raise
        return False
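
A brief sketch of calling these helpers directly, for example from a training script, to decide whether TF32 precision deserves a closer look (the printed message is illustrative only):

```python
from monai.utils.tf32 import detect_default_tf32, has_ampere_or_later

print("Ampere or later GPU present:", has_ampere_or_later())

if detect_default_tf32():
    # PyTorch defaults (versions 1.7-1.11) or an override environment variable
    # may have enabled TF32; detailed warnings are emitted the first time this returns True.
    print("TF32 may be active by default; verify accuracy or set the allow_tf32 flags explicitly.")
```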
9 changes: 7 additions & 2 deletions tests/test_version_leq.py → tests/test_version.py
@@ -16,7 +16,7 @@

from parameterized import parameterized

from monai.utils import version_leq
from monai.utils import version_geq, version_leq


# from pkg_resources
@@ -76,10 +76,15 @@ def _pairwise(iterable):

class TestVersionCompare(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_compare(self, a, b, expected=True):
    def test_compare_leq(self, a, b, expected=True):
        """Test version_leq with `a` and `b`"""
        self.assertEqual(version_leq(a, b), expected)

    @parameterized.expand(TEST_CASES)
    def test_compare_geq(self, a, b, expected=True):
        """Test version_geq with `b` and `a`"""
        self.assertEqual(version_geq(b, a), expected)


if __name__ == "__main__":
    unittest.main()
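
The new geq test reuses the leq cases with the arguments swapped; the identity it relies on, sketched below, is that `a <= b` holds exactly when `b >= a` for well-formed versions:

```python
from monai.utils import version_geq, version_leq

a, b = "1.6.0", "1.7.0"
assert version_leq(a, b) == version_geq(b, a)  # swapping arguments flips the comparison
```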
16 changes: 6 additions & 10 deletions tests/utils.py
@@ -46,7 +46,8 @@
from monai.data.meta_tensor import MetaTensor, get_track_meta
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.module import pytorch_after, version_leq
from monai.utils.module import pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")
@@ -172,19 +173,14 @@ def test_is_quick():

def is_tf32_env():
    """
    The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults
    or programmatic configuration of NVIDIA libraries, and consequently,
    cuBLAS will not accelerate FP32 computations with TF32 tensor cores.
    When we may be using TF32 mode, check the precision of matrix operation.
    If the checking result is greater than the threshold 0.001,
    set _tf32_enabled=True (and relax _rtol for tests).
    """
    global _tf32_enabled
    if _tf32_enabled is None:
        _tf32_enabled = False
        if (
            torch.cuda.is_available()
            and not version_leq(f"{torch.version.cuda}", "10.100")
            and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0"
            and torch.cuda.device_count() > 0  # at least 11.0
        ):
        if detect_default_tf32() or torch.backends.cuda.matmul.allow_tf32:
            try:
                # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
                g_gpu = torch.Generator(device="cuda")
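
The hunk above is truncated, but the idea of the precision check it relies on can be sketched roughly as follows (an approximation, not the exact test code; the threshold mirrors the 0.001 mentioned in the new docstring):

```python
import torch


def tf32_seems_active(size: int = 4096, threshold: float = 1e-3) -> bool:
    # Multiply two large random matrices on the GPU and compare the float32
    # result (possibly TF32-accelerated) against a float64 reference.
    g = torch.Generator(device="cuda")
    g.manual_seed(2147483647)
    a = torch.randn(size, size, device="cuda", generator=g)
    b = torch.randn(size, size, device="cuda", generator=g)
    ref = (a.double() @ b.double()).float()
    err = (a @ b - ref).abs().max().item()
    return err > threshold
```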