Skip to content

Commit

Permalink
[Refactor] Update nn inline_inbuilt check
Browse files Browse the repository at this point in the history
ghstack-source-id: 86c8a6dacd50387f76fd0a5b9ec9fd643b6d057f
Pull Request resolved: #1029
  • Loading branch information
vmoens committed Oct 4, 2024
1 parent 04faf40 commit d605e5c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2507,9 +2507,9 @@ def _lock_warn():


def _check_inbuild():
if not strtobool(os.environ.get("TORCHDYNAMO_INLINE_INBUILT_NN_MODULES", "0")):
if not torch._dynamo.config.inline_inbuilt_nn_modules:
raise RuntimeError(
"to_module requires TORCHDYNAMO_INLINE_INBUILT_NN_MODULES to be set."
"to_module requires torch._dynamo.config.inline_inbuilt_nn_modules to be set to True."
)


Expand Down
15 changes: 7 additions & 8 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import contextlib
import importlib.util
import inspect
import os
from pathlib import Path
from typing import Any, Callable

Expand Down Expand Up @@ -662,10 +661,10 @@ def test_dispatch_tensor(self, mode):
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
class TestFunctional:
def test_functional_error(self, mode):
TORCHDYNAMO_INLINE_INBUILT_NN_MODULES = os.environ.get(
"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"
TORCHDYNAMO_INLINE_INBUILT_NN_MODULES = (
torch._dynamo.config.inline_inbuilt_nn_modules
)
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
torch._dynamo.config.inline_inbuilt_nn_modules = True
module = torch.nn.Sequential(
torch.nn.Linear(3, 4),
torch.nn.ReLU(),
Expand All @@ -675,7 +674,7 @@ def test_functional_error(self, mode):
td_zero = TensorDictParams(td.data.clone())
td_zero.zero_()

os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "0"
torch._dynamo.config.inline_inbuilt_nn_modules = False
try:

def call(x, td):
Expand All @@ -685,12 +684,12 @@ def call(x, td):
call_compile = torch.compile(call, fullgraph=True, mode=mode)
x = torch.randn(2, 3)
with pytest.raises(
RuntimeError, match="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"
RuntimeError, match="torch._dynamo.config.inline_inbuilt_nn_modules"
):
call_compile(x, td_zero)
finally:
if TORCHDYNAMO_INLINE_INBUILT_NN_MODULES is not None:
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = (
if torch._dynamo.config.inline_inbuilt_nn_modules is not None:
torch._dynamo.config.inline_inbuilt_nn_modules = (
TORCHDYNAMO_INLINE_INBUILT_NN_MODULES
)

Expand Down

0 comments on commit d605e5c

Please sign in to comment.