From 7dd385b96dc0a75a851ca8a100f4765791cbd7aa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 11:23:45 +0000 Subject: [PATCH] [Deprecation] Softly deprecate extra-tensors wrt out_keys ghstack-source-id: aea9814fecbab903ad22ae54903a2921f4b88c5b Pull Request resolved: https://github.com/pytorch/tensordict/pull/1215 --- tensordict/nn/common.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 02b98cf30..040b7f50a 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1055,7 +1055,21 @@ def _write_to_tensordict( tensordict_out = TensorDict() else: tensordict_out = tensordict - for _out_key, _tensor in _zip_strict(out_keys, tensors): + if len(tensors) > len(out_keys): + incipit = "There are more tensors than out_keys. " + elif len(out_keys) > len(tensors): + incipit = "There are more out_keys than tensors. " + else: + incipit = None + if incipit is not None: + warnings.warn( + incipit + "This is currently " + "allowed but it will be deprecated in v0.9. To silence this warning, " + "make sure the number of out_keys matches the number of outputs of the " + "network.", + category=DeprecationWarning, + ) + for _out_key, _tensor in zip(out_keys, tensors): if _out_key != "_": tensordict_out.set(_out_key, TensorDict.from_any(_tensor)) return tensordict_out