From e49ea50953ad45abe20878a70dcb8fb4449aafc1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Feb 2025 18:13:34 +0000 Subject: [PATCH] init --- .github/workflows/wheels-windows.yml | 2 +- tensordict/nn/distributions/composite.py | 4 +--- test/test_compile.py | 7 +++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheels-windows.yml b/.github/workflows/wheels-windows.yml index 15032eef3..204714567 100644 --- a/.github/workflows/wheels-windows.yml +++ b/.github/workflows/wheels-windows.yml @@ -43,7 +43,7 @@ jobs: - name: Upload wheel for download uses: actions/upload-artifact@v4 with: - name: tensordict-batch.whl + name: tensordict-win-${{ matrix.python_version[0] }}.whl path: dist/*.whl test-wheel-windows: diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index 53ec3441d..32b6d6ed3 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -304,9 +304,7 @@ def maybe_deterministic_sample(dist): if hasattr(dist, "deterministic_sample"): return dist.deterministic_sample else: - from tensordict.nn.probabilistic import ( - DETERMINISTIC_REGISTER, - ) + from tensordict.nn.probabilistic import DETERMINISTIC_REGISTER # Fallbacks tdist = type(dist) diff --git a/test/test_compile.py b/test/test_compile.py index 2d6792348..399ca9407 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -6,6 +6,7 @@ import contextlib import importlib.util import inspect +import platform from pathlib import Path from typing import Any, Callable @@ -43,6 +44,8 @@ _v2_5 = TORCH_VERSION >= version.parse("2.5.0") +_IS_OSX = platform.system() == "Darwin" + def test_vmap_compile(): # Since we monkey patch vmap we need to make sure compile is happy with it @@ -952,6 +955,10 @@ def to_numpy(tensor): @pytest.mark.skipif( TORCH_VERSION <= version.parse("2.4.1"), reason="requires torch>=2.5" ) +@pytest.mark.skipif( + (TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX, + reason="requires torch>=2.7 ons OSX", +) @pytest.mark.parametrize("compiled", [False, True]) class TestCudaGraphs: @pytest.fixture(scope="class", autouse=True)