Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 26, 2024
1 parent a9f09cc commit 4d57176
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
6 changes: 6 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import multiprocessing
import os
import time
from collections import defaultdict

import pytest

try:
multiprocessing.set_start_method("spawn")
except Exception:
assert multiprocessing.get_start_method() == "spawn"

CALL_TIMES = defaultdict(lambda: 0.0)


Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
),
]

mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
mp_ctx = "spawn"


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tutorials/sphinx_tuto/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@

t0 = time.time()
model(x=x)
print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
print(f"Time for TDModule: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")
exported = model_export.module()

# Exported version
t0 = time.time()
exported(x=x)
print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
print(f"Time for exported module: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")

##################################################
# and the FX graph:
Expand Down
12 changes: 6 additions & 6 deletions tutorials/sphinx_tuto/tensordict_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,33 +213,33 @@ def forward(self, x):
from torch.utils.benchmark import Timer

print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

###############################################################################
Expand Down

0 comments on commit 4d57176

Please sign in to comment.