Skip to content

Commit

Permalink
bump: Torch 2.5 (#20351)
Browse files Browse the repository at this point in the history
* bump: Torch `2.5.0`

* push docker

* docker

* 2.5.1 and mypy

* update USE_DISTRIBUTED=0 test

* also for pytorch lightning no distributed

* set USE_LIBUV=0 on windows

* try drop pickle warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable compiling update_metrics

* bump 2.2.x to bugfix

* disable also log in logger connector (also calls metric)

* more point release bumps

* remove unloved type ignore and print some more on exit

* update checkgroup

* minor versions

* shortened version in build-pl

* pytorch 2.4 is with python 3.11

* 2.1 and 2.3 without patch release

* for 2.4.1: docker with 3.11 test with 3.12

---------

Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 61a403a)
  • Loading branch information
Borda committed Nov 12, 2024
1 parent 24478af commit dabbf83
Show file tree
Hide file tree
Showing 26 changed files with 117 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .azure/gpu-benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
container:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
options: "--gpus=all --shm-size=32g"
strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .azure/gpu-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "fabric"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
PACKAGE_NAME: "lightning"
workspace:
clean: all
Expand Down
2 changes: 1 addition & 1 deletion .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "pytorch"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0"
PACKAGE_NAME: "lightning"
pool: lit-rtx-3090
variables:
Expand Down
44 changes: 26 additions & 18 deletions .github/checkgroup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,22 @@ subprojects:
checks:
- "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.2)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)"
- "pl-cpu (macOS-14, lightning, 3.11, 2.3)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.4)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)"
- "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)"
- "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
- "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
- "pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
- "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
- "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.2)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)"
- "pl-cpu (windows-2022, lightning, 3.11, 2.3)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.4)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)"
- "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)"
- "pl-cpu (macOS-14, pytorch, 3.9, 2.1)"
- "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)"
- "pl-cpu (windows-2022, pytorch, 3.9, 2.1)"
Expand Down Expand Up @@ -141,15 +144,17 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "build-cuda (3.11, 2.1, 12.1.0)"
- "build-cuda (3.11, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.3, 12.1.0)"
- "build-cuda (3.12, 2.4, 12.1.0)"
- "build-cuda (3.10, 2.1.2, 12.1.0)"
- "build-cuda (3.11, 2.2.2, 12.1.0)"
- "build-cuda (3.11, 2.3.1, 12.1.0)"
- "build-cuda (3.11, 2.4.1, 12.1.0)"
- "build-cuda (3.12, 2.5.1, 12.1.0)"
#- "build-NGC"
- "build-pl (3.11, 2.1, 12.1.0)"
- "build-pl (3.10, 2.1, 12.1.0)"
- "build-pl (3.11, 2.2, 12.1.0)"
- "build-pl (3.11, 2.3, 12.1.0)"
- "build-pl (3.12, 2.4, 12.1.0)"
- "build-pl (3.11, 2.4, 12.1.0)"
- "build-pl (3.12, 2.5, 12.1.0)"

# SECTION: lightning_fabric

Expand All @@ -168,19 +173,22 @@ subprojects:
checks:
- "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (macOS-14, lightning, 3.10, 2.1)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.3)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)"
- "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.12, 2.4)"
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)"
- "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)"
- "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)"
- "fabric-cpu (windows-2022, lightning, 3.10, 2.1)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)"
- "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)"
- "fabric-cpu (macOS-14, fabric, 3.9, 2.1)"
- "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)"
- "fabric-cpu (windows-2022, fabric, 3.9, 2.1)"
Expand Down
15 changes: 9 additions & 6 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-13", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" }
Expand Down
15 changes: 9 additions & 6 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,18 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-13", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" }
Expand Down
22 changes: 15 additions & 7 deletions .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ jobs:
include:
# We only release one docker image per PyTorch version.
# Make sure the matrix here matches the one below.
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
with:
Expand Down Expand Up @@ -103,10 +104,11 @@ jobs:
include:
# These are the base images for PL release docker images.
# Make sure the matrix here matches the one above.
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.1.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.0" }
- { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
Expand All @@ -115,6 +117,12 @@ jobs:
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}

- name: shorten Torch version
run: |
# convert 1.10.2 to 1.10
pt_version=$(echo ${{ matrix.pytorch_version }} | cut -d. -f1,2)
echo "PT_VERSION=$pt_version" >> $GITHUB_ENV
- uses: docker/build-push-action@v6
with:
build-args: |
Expand All @@ -123,7 +131,7 @@ jobs:
CUDA_VERSION=${{ matrix.cuda_version }}
file: dockers/base-cuda/Dockerfile
push: ${{ env.PUSH_NIGHTLY }}
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}"
tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}"
timeout-minutes: 95
- uses: ravsamhq/notify-slack-action@v2
if: failure() && env.PUSH_NIGHTLY == 'true'
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch >=2.1.0, <2.5.0
torch >=2.1.0, <2.6.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
Expand Down
4 changes: 2 additions & 2 deletions requirements/fabric/examples.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
torchvision >=0.16.0, <0.21.0
torchmetrics >=0.10.0, <1.5.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/fabric/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
click ==8.1.7
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
4 changes: 2 additions & 2 deletions requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch >=2.1.0, <2.5.0
torch >=2.1.0, <2.6.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.10.0, <0.12.0
4 changes: 2 additions & 2 deletions requirements/pytorch/examples.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

requests <2.32.0
torchvision >=0.16.0, <0.20.0
torchvision >=0.16.0, <0.21.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
torchmetrics >=0.10.0, <1.5.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mypy==1.11.0
torch==2.4.1
torch==2.5.1

types-Markdown
types-PyYAML
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import sys

from lightning_utilities.core.imports import package_available

Expand All @@ -26,6 +27,10 @@
# https://github.com/pytorch/pytorch/issues/83973
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

# see https://github.com/pytorch/pytorch/issues/139990
if sys.platform == "win32":
os.environ["USE_LIBUV"] = "0"


from lightning.fabric.fabric import Fabric # noqa: E402
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
Expand Down
6 changes: 4 additions & 2 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def log(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx, # type: ignore[arg-type]
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
Expand Down Expand Up @@ -1405,7 +1405,9 @@ def forward(self, x):
input_sample = self._apply_batch_transfer_handler(input_sample)

file_path = str(file_path) if isinstance(file_path, Path) else file_path
torch.onnx.export(self, input_sample, file_path, **kwargs)
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
# BytesIO does work, too.
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)

@torch.no_grad()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], m

return batch_size

@torch.compiler.disable
def log(
self,
fx: str,
Expand Down Expand Up @@ -413,6 +414,7 @@ def log(
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

@torch.compiler.disable
def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
Expand Down
1 change: 1 addition & 0 deletions tests/run_standalone_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function show_batched_output {
# heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail
if perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt | grep -qv -f testnames.txt; then
echo "Potential error! Stopping."
perl -nle 'print if /error|(?<!(?-i)on_)exception|traceback|(?<!(?-i)x)failed/i' standalone_test_output.txt
rm standalone_test_output.txt
exit 1
fi
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_fabric/utilities/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def test_import_fabric_with_torch_dist_unavailable():
code = dedent(
"""
import torch
try:
# PyTorch 2.5 relies on torch,distributed._composable.fsdp not
# existing with USE_DISTRIBUTED=0
import torch._dynamo.variables.functions
torch._dynamo.variables.functions._fsdp_param_group = None
except ImportError:
pass
# pretend torch.distributed not available
for name in list(torch.distributed.__dict__.keys()):
Expand All @@ -31,6 +38,11 @@ def test_import_fabric_with_torch_dist_unavailable():
torch.distributed.is_available = lambda: False
# needed for Dynamo in PT 2.5+ compare the torch.distributed source
class _ProcessGroupStub:
pass
torch.distributed.ProcessGroup = _ProcessGroupStub
import lightning.fabric
"""
)
Expand Down
8 changes: 2 additions & 6 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
import math
import os
import pickle
from contextlib import nullcontext
from typing import List, Optional
from unittest import mock
from unittest.mock import Mock

import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -193,13 +191,11 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")

early_stopping_pickled = pickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = pickle.loads(early_stopping_pickled)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

early_stopping_pickled = cloudpickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)


Expand Down
Loading

0 comments on commit dabbf83

Please sign in to comment.