From 3e976ac49a498c936acffdca7e8077ff77c97a8f Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Nov 2023 17:54:09 +0000 Subject: [PATCH 1/6] init --- benchmarks/nn/functional_benchmarks_test.py | 111 ++++++++++++++++++++ tensordict/_lazy.py | 2 + tensordict/_td.py | 47 +++++---- tensordict/base.py | 8 +- tensordict/nn/common.py | 2 +- tensordict/utils.py | 9 +- test/test_nn.py | 82 +++++++++++++++ 7 files changed, 236 insertions(+), 25 deletions(-) diff --git a/benchmarks/nn/functional_benchmarks_test.py b/benchmarks/nn/functional_benchmarks_test.py index 096b3901b..0308797eb 100644 --- a/benchmarks/nn/functional_benchmarks_test.py +++ b/benchmarks/nn/functional_benchmarks_test.py @@ -120,6 +120,16 @@ def test_instantiation_td(benchmark, net): # Execution def test_exec_functorch(benchmark, net): + x = torch.randn(2, 2) + sd = net.state_dict() + + def fun(x, sd): + torch.func.functional_call(net, sd, x) + + benchmark(fun, x, sd) + + +def test_exec_functional_call(benchmark, net): x = torch.randn(2, 2) fmodule, params, buffers = functorch_make_functional(net) benchmark(fmodule, params, buffers, x) @@ -132,6 +142,18 @@ def test_exec_td(benchmark, net): benchmark(fmodule, x, params=params) +def test_exec_td_decorator(benchmark, net): + x = torch.randn(2, 2) + fmodule = net + params = TensorDict.from_module(fmodule) + + def fun(x, params): + with params.to_module(net): + net(x) + + benchmark(fun, x, params) + + @torch.no_grad() @pytest.mark.parametrize("stack", [True, False]) @pytest.mark.parametrize("tdmodule", [True, False]) @@ -169,6 +191,48 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule): benchmark(fun, x, params) +@torch.no_grad() +@pytest.mark.parametrize("stack", [True, False]) +@pytest.mark.parametrize("tdmodule", [True, False]) +def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule): + # tests speed of vmapping over a transformer + device = "cuda" if torch.cuda.device_count() else "cpu" + t = nn.Sequential( + nn.Linear(64, 64, device=device), + nn.ReLU(), + nn.Linear(64, 64, device=device), + nn.ReLU(), + nn.Linear(64, 64, device=device), + nn.ReLU(), + nn.Linear(64, 64, device=device), + nn.ReLU(), + ) + if tdmodule: + t = TensorDictModule(t, in_keys=["x"], out_keys=["y"]) + + x = torch.randn(1, 1, 64, device=device) + t.eval() + params = TensorDict.from_module(t) + if not stack: + params = params.expand(2).to_tensordict().lock_() + else: + params = torch.stack([params, params.clone()], 0).lock_() + + def fun(x, params): + with params.to_module(t): + return t(x) + + vfun = vmap(fun, (None, 0)) + + if tdmodule: + data = TensorDict({"x": x}, []) + vfun(data, params) + benchmark(vfun, data, params) + else: + vfun(x, params) + benchmark(vfun, x, params) + + @torch.no_grad() @pytest.mark.skipif( not torch.cuda.device_count(), reason="cuda device required for test" @@ -208,6 +272,53 @@ def test_vmap_transformer_speed(benchmark, stack, tdmodule): benchmark(fun, x, x, params) +@torch.no_grad() +@pytest.mark.skipif( + not torch.cuda.device_count(), reason="cuda device required for test" +) +@pytest.mark.parametrize("stack", [True, False]) +@pytest.mark.parametrize("tdmodule", [True, False]) +def test_vmap_transformer_speed_decorator(benchmark, stack, tdmodule): + # tests speed of vmapping over a transformer + device = "cuda" if torch.cuda.device_count() else "cpu" + t = torch.nn.Transformer( + 8, + dim_feedforward=8, + device=device, + batch_first=False, + ) + if tdmodule: + t = TensorDictModule(t, in_keys=["x", "x"], out_keys=["y"]) + + x = torch.randn(2, 2, 8, device=device) 
+ t.eval() + params = TensorDict.from_module(t) + if not stack: + params = params.expand(2).to_tensordict().lock_() + else: + params = torch.stack([params, params.clone()], 0).lock_() + + if tdmodule: + + def fun(x, params): + with params.to_module(t): + return t(x) + + vfun = vmap(fun, (None, 0)) + data = TensorDict({"x": x}, []) + vfun(data, params) + benchmark(vfun, data, params) + else: + + def fun(x, params): + with params.to_module(t): + return t(x, x) + + vfun = vmap(fun, (None, 0)) + vfun(x, params) + benchmark(vfun, x, params) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 84b5bf1bb..488681dcd 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -863,6 +863,8 @@ def hook_in( batch_size=n, vmap_level=vmap_level, ): + if _is_tensor_collection(type(tensor)): + return tensor._remove_batch_dim(vmap_level, batch_size, out_dim) return _remove_batch_dim(tensor, vmap_level, batch_size, out_dim) out.hook_out = hook_out diff --git a/tensordict/_td.py b/tensordict/_td.py index fd84fa9a3..f27c6f3ac 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -259,20 +259,23 @@ def from_module( return td_struct @as_decorator() - def to_module(self, module, return_swap: bool = True, swap_dest=None): - + def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None): # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can __dict__ = module.__dict__ - out = None + swap = None has_set_device = False + if memo is None: + memo = {} if return_swap: # this could break if the device and batch-size are not congruent. # For batch-size it is a minor issue (unlikely that a td with batch-size # is passed with to_module) but for the device it could be a problem. 
if swap_dest is None: - out = self.empty() + swap = self.empty() + swap.clear_device_() else: - out = swap_dest + swap = swap_dest + memo[id(module)] = swap for key, value in self.items(): if isinstance(value, (Tensor, ftdim.Tensor)): @@ -295,24 +298,30 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None): local_dest = swap_dest._get_str(key, default=NO_DEFAULT) else: local_dest = None - local_out = value.to_module( - __dict__["_modules"][key], - return_swap=return_swap, - swap_dest=local_dest, - ) - if return_swap: + child = __dict__["_modules"][key] + if id(child) in memo: + local_out = memo[id(child)] + else: + local_out = value.to_module( + child, + return_swap=return_swap, + swap_dest=local_dest, + memo=memo, + ) # we don't want to do this op more than once - if ( + if return_swap and ( not has_set_device - and out.device is not None + and swap.device is not None and local_out.device is not None - and local_out.device != out.device + and local_out.device != swap.device ): has_set_device = True # map out to the local_out device - out = out.to(device=local_out.device) - out._set_str(key, local_out, inplace=False, validated=True) - return out + swap = swap.to(device=local_out.device) + + if return_swap: + swap._set_str(key, local_out, inplace=False, validated=True) + return swap def __ne__(self, other: object) -> T | bool: if _is_tensorclass(other): @@ -1237,8 +1246,8 @@ def _set_str( if not validated: value = self._validate_value(value, check_shape=True) if not inplace: - if self.is_locked: - raise RuntimeError(_LOCK_ERROR) + # if self.is_locked: + # raise RuntimeError(_LOCK_ERROR) self._tensordict[key] = value else: try: diff --git a/tensordict/base.py b/tensordict/base.py index aec2f9fec..1335ec6f9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -328,7 +328,9 @@ def from_module(module, as_module: bool = False, lock: bool = True): ... @abc.abstractmethod - def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None): + def to_module( + self, module: nn.Module, return_swap: bool = False, swap_dest=None, memo=None + ): """Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively. Args: @@ -337,6 +339,10 @@ def to_module(self, module: nn.Module, return_swap: bool = False, swap_dest=None will be returned. Defaults to ``False``. swap_dest (TensorDictBase, optional): if ``return_swap`` is ``True``, the tensordict where the swap should be written. + memo (dict, optional): when the same module is present multiple times + in the input module, a memo is used to avoid fetching the params + that have just been set. This argument should be ignored during + regular calls to `to_module`. 
Examples: >>> from torch import nn diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index b5d379b91..e8fcd9bd8 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1091,7 +1091,7 @@ def __init__( ) self.module = module - make_functional(self, keep_params=True, return_params=False) + # make_functional(self, keep_params=True, return_params=False) @property def is_functional(self) -> bool: diff --git a/tensordict/utils.py b/tensordict/utils.py index 7ec0056b5..8c462d1dc 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1123,10 +1123,11 @@ def new_func(_self, *args, **kwargs): out = func(_self, *args, **kwargs) if self.attr is not None: _attr_post = getattr(_self, self.attr) - if self.attr is None or (_attr_post is not _attr_pre): - out._last_op = (new_func.__name__, (args, kwargs, _self)) - else: - out._last_op = None + if out is not None: + if self.attr is None or (_attr_post is not _attr_pre): + out._last_op = (new_func.__name__, (args, kwargs, _self)) + else: + out._last_op = None return out return new_func diff --git a/test/test_nn.py b/test/test_nn.py index c45158340..c3c9f1de9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3198,6 +3198,88 @@ def test_sd_module(self): assert isinstance(val, nn.Parameter) +@pytest.mark.parametrize( + "module_name,input_name", [["_module_shared", "_x"], ["_transformer", "_tuple_x"]] +) +class TestToModule: + @property + def _transformer(self): + # we use transformer because it's deep, has buffers etc. + return nn.Transformer(d_model=8, dim_feedforward=8).eval() + + @property + def _module_shared(self): + # a module with the same layer appearing twice + l0 = nn.Linear(8, 9) + l1 = nn.Linear(9, 8) + return nn.Sequential( + l0, + l1, + nn.Sequential( + l0, + ), + ) + + @property + def _tuple_x(self): + x = torch.randn(2, 2, 8) + return (x, x) + + @property + def _x(self): + return (torch.randn(2, 2, 8),) + + def test_static(self, module_name, input_name): + torch.manual_seed(0) + module = getattr(self, module_name) + x = getattr(self, input_name) + params = TensorDict.from_module(module) + params0 = params.clone().zero_() + y = module(*x) + params0.to_module(module) + y0 = module(*x) + params.to_module(module) + y1 = module(*x) + torch.testing.assert_close(y, y1) + assert (y0 == 0).all() + assert (y0 != y1).all() + + def test_cm(self, module_name, input_name): + torch.manual_seed(0) + module = getattr(self, module_name) + x = getattr(self, input_name) + params = TensorDict.from_module(module) + params0 = params.clone().apply( + lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0, + params, + ) + y = module(*x) + with params0.to_module(module): + y0 = module(*x) + assert (params0 == TensorDict.from_module(module)).all() + y1 = module(*x) + torch.testing.assert_close(y, y1) + assert (y0 == 0).all() + assert (y0 != y1).all() + assert (TensorDict.from_module(module) == params).all() + + def test_cm_meta(self, module_name, input_name): + torch.manual_seed(0) + module = getattr(self, module_name) + x = getattr(self, input_name) + params = TensorDict.from_module(module) + params_meta = params.detach().to("meta") + y = module(*x) + with params_meta.to_module(module): + module_meta = copy.deepcopy(module) + y1 = module(*x) + with params.to_module(module_meta): + y2 = module_meta(*x) + torch.testing.assert_close(y, y1) + torch.testing.assert_close(y, y2) + assert (TensorDict.from_module(module) == params).all() + + if __name__ == "__main__": args, unknown = 
argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From ab88669b507620476f2d0dcae1bec50bed802ff3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Nov 2023 21:03:19 +0000 Subject: [PATCH 2/6] amend --- tensordict/nn/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index e8fcd9bd8..b5d379b91 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1091,7 +1091,7 @@ def __init__( ) self.module = module - # make_functional(self, keep_params=True, return_params=False) + make_functional(self, keep_params=True, return_params=False) @property def is_functional(self) -> bool: From 7b2617586446d614b114c7ef666b79ba7088d389 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Nov 2023 21:04:48 +0000 Subject: [PATCH 3/6] amend --- tensordict/_td.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index f27c6f3ac..2f67a025e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1246,8 +1246,8 @@ def _set_str( if not validated: value = self._validate_value(value, check_shape=True) if not inplace: - # if self.is_locked: - # raise RuntimeError(_LOCK_ERROR) + if self.is_locked: + raise RuntimeError(_LOCK_ERROR) self._tensordict[key] = value else: try: From a6ccca9c7e26cc4e94260aac4dbed139c703a106 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 08:12:12 +0000 Subject: [PATCH 4/6] amend --- tensordict/_td.py | 5 +++++ test/test_nn.py | 18 +++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 2f67a025e..58214e586 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -263,6 +263,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can __dict__ = module.__dict__ swap = None + _is_locked = False has_set_device = False if memo is None: memo = {} @@ -275,6 +276,8 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap.clear_device_() else: swap = swap_dest + _is_locked = swap._is_locked + swap._is_locked = False memo[id(module)] = swap for key, value in self.items(): @@ -321,6 +324,8 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) if return_swap: swap._set_str(key, local_out, inplace=False, validated=True) + if swap is not None and _is_locked: + self._is_locked = _is_locked return swap def __ne__(self, other: object) -> T | bool: diff --git a/test/test_nn.py b/test/test_nn.py index c3c9f1de9..13e420a67 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3201,6 +3201,7 @@ def test_sd_module(self): @pytest.mark.parametrize( "module_name,input_name", [["_module_shared", "_x"], ["_transformer", "_tuple_x"]] ) +@pytest.mark.parametrize("as_module", [True, False]) class TestToModule: @property def _transformer(self): @@ -3229,12 +3230,15 @@ def _tuple_x(self): def _x(self): return (torch.randn(2, 2, 8),) - def test_static(self, module_name, input_name): + def test_static(self, module_name, input_name, as_module): torch.manual_seed(0) module = getattr(self, module_name) x = getattr(self, input_name) - params = TensorDict.from_module(module) - params0 = params.clone().zero_() + params = TensorDict.from_module(module, as_module=as_module) + params0 = params.clone().apply( + lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t 
* 0, + params, + ) y = module(*x) params0.to_module(module) y0 = module(*x) @@ -3244,11 +3248,11 @@ def test_static(self, module_name, input_name): assert (y0 == 0).all() assert (y0 != y1).all() - def test_cm(self, module_name, input_name): + def test_cm(self, module_name, input_name, as_module): torch.manual_seed(0) module = getattr(self, module_name) x = getattr(self, input_name) - params = TensorDict.from_module(module) + params = TensorDict.from_module(module, as_module=as_module) params0 = params.clone().apply( lambda t, p: nn.Parameter(t * 0) if isinstance(p, nn.Parameter) else t * 0, params, @@ -3263,11 +3267,11 @@ def test_cm(self, module_name, input_name): assert (y0 != y1).all() assert (TensorDict.from_module(module) == params).all() - def test_cm_meta(self, module_name, input_name): + def test_cm_meta(self, module_name, input_name, as_module): torch.manual_seed(0) module = getattr(self, module_name) x = getattr(self, input_name) - params = TensorDict.from_module(module) + params = TensorDict.from_module(module, as_module=as_module) params_meta = params.detach().to("meta") y = module(*x) with params_meta.to_module(module): From 9040fa3b01c96e6e11547b74e248427962727e0d Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 09:15:35 +0000 Subject: [PATCH 5/6] remove ops from workflow --- .github/workflows/benchmarks_pr.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 96d6bac11..ac77fbabf 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -52,8 +52,8 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} benchmark-file: ${{ env.CONTENDER_JSON }} comparison-benchmark-file: ${{ env.BASELINE_JSON }} - benchmark-metrics: 'name,max,mean,ops' - comparison-benchmark-metric: 'ops' + benchmark-metrics: 'name,max,mean' + comparison-benchmark-metric: 'mean' comparison-higher-is-better: true comparison-threshold: 5 benchmark-title: 'Result of CPU Benchmark Tests' @@ -139,8 +139,8 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} benchmark-file: ${{ env.CONTENDER_JSON }} comparison-benchmark-file: ${{ env.BASELINE_JSON }} - benchmark-metrics: 'name,max,mean,ops' - comparison-benchmark-metric: 'ops' + benchmark-metrics: 'name,max,mean' + comparison-benchmark-metric: 'mean' comparison-higher-is-better: true comparison-threshold: 5 benchmark-title: 'Result of GPU Benchmark Tests' From 1d97ff23a65a71af5af1d4d03c23c095ff7f222a Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 11:02:44 +0000 Subject: [PATCH 6/6] amend --- .github/workflows/benchmarks_pr.yml | 8 ++++---- tensordict/_td.py | 6 +----- tensordict/base.py | 3 ++- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index ac77fbabf..96d6bac11 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -52,8 +52,8 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} benchmark-file: ${{ env.CONTENDER_JSON }} comparison-benchmark-file: ${{ env.BASELINE_JSON }} - benchmark-metrics: 'name,max,mean' - comparison-benchmark-metric: 'mean' + benchmark-metrics: 'name,max,mean,ops' + comparison-benchmark-metric: 'ops' comparison-higher-is-better: true comparison-threshold: 5 benchmark-title: 'Result of CPU Benchmark Tests' @@ -139,8 +139,8 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} benchmark-file: ${{ env.CONTENDER_JSON }} comparison-benchmark-file: ${{ env.BASELINE_JSON }} - benchmark-metrics: 
'name,max,mean' - comparison-benchmark-metric: 'mean' + benchmark-metrics: 'name,max,mean,ops' + comparison-benchmark-metric: 'ops' comparison-higher-is-better: true comparison-threshold: 5 benchmark-title: 'Result of GPU Benchmark Tests' diff --git a/tensordict/_td.py b/tensordict/_td.py index 58214e586..2e01b91cd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -263,7 +263,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can __dict__ = module.__dict__ swap = None - _is_locked = False has_set_device = False if memo is None: memo = {} @@ -276,8 +275,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap.clear_device_() else: swap = swap_dest - _is_locked = swap._is_locked - swap._is_locked = False memo[id(module)] = swap for key, value in self.items(): @@ -323,9 +320,8 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap = swap.to(device=local_out.device) if return_swap: + assert local_out is not None, key swap._set_str(key, local_out, inplace=False, validated=True) - if swap is not None and _is_locked: - self._is_locked = _is_locked return swap def __ne__(self, other: object) -> T | bool: diff --git a/tensordict/base.py b/tensordict/base.py index 1335ec6f9..74e4980f3 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3125,7 +3125,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() if last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): - return self.to_module(*args, **kwargs, swap_dest=out) + with out.unlock_(): + return self.to_module(*args, **kwargs, swap_dest=out) else: raise RuntimeError( "to_module cannot be used as a decorator when return_swap=False."
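
---

The pattern exercised by the new `*_decorator` benchmarks and `TestToModule` tests can be summarised in a short sketch. This is not part of the patch; it is an illustration based on the benchmark code above, assuming tensordict with this series applied and PyTorch >= 2.0. The benchmark's own `vmap` import line is outside the shown hunks, so `torch.vmap` is assumed here, and names such as `net` and `call` are illustrative only.

```python
import torch
from torch import nn
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(1, 64)

# Extract parameters and buffers as a TensorDict mirroring the module tree.
params = TensorDict.from_module(net)

def call(x, p):
    # Swap `p` into `net` for the duration of the block; the returned swap
    # restores the original parameters on exit of the context manager.
    with p.to_module(net):
        return net(x)

# A batch of two parameter sets stacked along dim 0 (locking mirrors the benchmark).
stacked = torch.stack([params, params.clone()], 0).lock_()

# Map over the parameter batch only; `x` is shared across the batch.
y = torch.vmap(call, (None, 0))(x, stacked)
print(y.shape)  # torch.Size([2, 1, 64])
```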
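The `memo` argument introduced in this series covers modules that appear more than once in the tree, as explained in the `to_module` docstring above. A minimal illustration, modelled on the `_module_shared` fixture and `test_cm` in the new tests (variable names are illustrative, not from the patch):

```python
import torch
from torch import nn
from tensordict import TensorDict

l0 = nn.Linear(8, 9)
l1 = nn.Linear(9, 8)
shared = nn.Sequential(l0, l1, nn.Sequential(l0))  # l0 appears twice in the tree

p = TensorDict.from_module(shared)
# Zeroed copy that keeps nn.Parameter leaves as parameters, as in the tests.
p0 = p.clone().apply(
    lambda t, q: nn.Parameter(t * 0) if isinstance(q, nn.Parameter) else t * 0,
    p,
)

with p0.to_module(shared):
    # Both occurrences of l0 now hold the zeroed weights; the memo reuses the
    # swap computed for the first occurrence instead of fetching it twice.
    assert (shared(torch.randn(2, 8)) == 0).all()

# On exit the original parameters are written back.
assert (TensorDict.from_module(shared) == p).all()
```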