Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 4, 2025
1 parent 5630fc8 commit e49ea50
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Upload wheel for download
uses: actions/upload-artifact@v4
with:
name: tensordict-batch.whl
name: tensordict-win-${{ matrix.python_version[0] }}.whl
path: dist/*.whl

test-wheel-windows:
Expand Down
4 changes: 1 addition & 3 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,7 @@ def maybe_deterministic_sample(dist):
if hasattr(dist, "deterministic_sample"):
return dist.deterministic_sample
else:
from tensordict.nn.probabilistic import (
DETERMINISTIC_REGISTER,
)
from tensordict.nn.probabilistic import DETERMINISTIC_REGISTER

# Fallbacks
tdist = type(dist)
Expand Down
7 changes: 7 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import importlib.util
import inspect
import platform
from pathlib import Path
from typing import Any, Callable

Expand Down Expand Up @@ -43,6 +44,8 @@

_v2_5 = TORCH_VERSION >= version.parse("2.5.0")

_IS_OSX = platform.system() == "Darwin"


def test_vmap_compile():
# Since we monkey patch vmap we need to make sure compile is happy with it
Expand Down Expand Up @@ -952,6 +955,10 @@ def to_numpy(tensor):
@pytest.mark.skipif(
TORCH_VERSION <= version.parse("2.4.1"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(
(TORCH_VERSION <= version.parse("2.7.0")) and _IS_OSX,
reason="requires torch>=2.7 ons OSX",
)
@pytest.mark.parametrize("compiled", [False, True])
class TestCudaGraphs:
@pytest.fixture(scope="class", autouse=True)
Expand Down

0 comments on commit e49ea50

Please sign in to comment.