Skip to content

Commit

Permalink
[Benchmark] Benchmark to_module (#669)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 7, 2024
1 parent 751091a commit ca92d20
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions benchmarks/nn/functional_benchmarks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,19 @@ def fun(x, params):
benchmark(vfun, x, params)


@pytest.mark.parametrize("tdparams", [True, False])
def test_to_module_speed(benchmark, tdparams):
module = torch.nn.Transformer()
params = TensorDict.from_module(module, as_module=tdparams)

def func(params=params, module=module):
with params.to_module(module):
pass
return

benchmark(func)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 comments on commit ca92d20

Please sign in to comment.