Skip to content

Commit

Permalink
[Deprecation] Softly deprecate extra-tensors wrt out_keys
Browse files Browse the repository at this point in the history
ghstack-source-id: aea9814fecbab903ad22ae54903a2921f4b88c5b
Pull Request resolved: #1215
  • Loading branch information
vmoens committed Feb 10, 2025
1 parent ba53d07 commit 7dd385b
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,21 @@ def _write_to_tensordict(
tensordict_out = TensorDict()
else:
tensordict_out = tensordict
for _out_key, _tensor in _zip_strict(out_keys, tensors):
if len(tensors) > len(out_keys):
incipit = "There are more tensors than out_keys. "
elif len(out_keys) > len(tensors):
incipit = "There are more out_keys than tensors. "
else:
incipit = None
if incipit is not None:
warnings.warn(
incipit + "This is currently "
"allowed but it will be deprecated in v0.9. To silence this warning, "
"make sure the number of out_keys matches the number of outputs of the "
"network.",
category=DeprecationWarning,
)
for _out_key, _tensor in zip(out_keys, tensors):
if _out_key != "_":
tensordict_out.set(_out_key, TensorDict.from_any(_tensor))
return tensordict_out
Expand Down

0 comments on commit 7dd385b

Please sign in to comment.