Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] intersection for assert_close #1078

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,13 @@ def densify(self, *, layout: torch.layout = torch.strided):
else:
raise NotImplementedError
else:
tensor = self._get_str(key).densify(layout=layout)
tensor = self._get_str(key, None)
if tensor is not None:
tensor = tensor.densify(layout=layout)
else:
from tensordict import NonTensorData

tensor = NonTensorData(None)
result._set_str(key, tensor, validated=True, inplace=False)
return result

Expand Down
26 changes: 1 addition & 25 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
_is_shared,
_KEY_ERROR,
_LOCK_ERROR,
_mismatch_keys,
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
Expand Down Expand Up @@ -5128,28 +5129,3 @@ def memmap(
return_early=return_early,
share_non_tensor=share_non_tensor,
)


def _mismatch_keys(keys1, keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
if set(keys1) - set(keys2):
sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."
else:
sub1 = None
if set(keys2) - set(keys1):
sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have."
else:
sub2 = None
main = [r"keys in tensordicts mismatch."]
if sub1 is not None:
main.append(sub1)
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))
92 changes: 82 additions & 10 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,9 +1499,33 @@ def assert_close(
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = True,
intersection: bool = False,
msg: str = "",
) -> bool:
"""Compares two tensordicts and raise an exception if their content does not match exactly."""
"""Asserts that two tensordicts, `actual` and `expected`, are element-wise equal within a tolerance for all entries.

This function checks if the elements of the `actual` tensor are close to the corresponding elements
of the `expected` tensordict, within a relative tolerance (`rtol`) and an absolute tolerance (`atol`).

It is similar to the :func:`~torch.testing.assert_close` function in PyTorch, but with tensordicts inputs.

Args:
actual (T): The tensordict containing actual values.
expected (T): The tensordict containing expected values.
rtol (float | None, optional): The relative tolerance parameter. Default is None.
atol (float | None, optional): The absolute tolerance parameter. Default is None.
equal_nan (bool, optional): If True, ``NaNs`` will be considered equal to ``NaNs``. Default is ``True``.
intersection (bool, optional): If True, only the intersection of the two tensordicts will be compared.
Default is ``False``.
msg (str, optional): An optional message to include in the assertion error if the check fails.

Returns:
bool: True if the tensors are close within the specified tolerances, raise an exception otherwise.

Raises:
AssertionError: If the tensordicts are not close within the specified tolerances.

"""
from tensordict.base import _is_tensor_collection

if not _is_tensor_collection(type(actual)) or not _is_tensor_collection(
Expand All @@ -1517,7 +1541,15 @@ def assert_close(
for sub_actual, sub_expected in _zip_strict(
actual.tensordicts, expected.tensordicts
):
assert_allclose_td(sub_actual, sub_expected, rtol=rtol, atol=atol)
assert_close(
sub_actual,
sub_expected,
rtol=rtol,
atol=atol,
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
)
return True

try:
Expand All @@ -1527,12 +1559,14 @@ def assert_close(
# Persistent tensordicts do not work with is_leaf
set1 = set(actual.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
set2 = set(expected.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
if not (len(set1.difference(set2)) == 0 and len(set2) == len(set1)):
raise KeyError(
"actual and expected tensordict keys mismatch, "
f"keys {(set1 - set2).union(set2 - set1)} appear in one but not "
f"the other."
)
if not intersection and (
not (len(set1.difference(set2)) == 0 and len(set2) == len(set1))
):
_mismatch_keys(set1, set2)
elif intersection and set1 != set2:
actual = actual.select(*set2, strict=False)
expected = expected.select(*set1, strict=False)

keys = sorted(actual.keys(), key=str)
for key in keys:
input1 = actual.get(key)
Expand All @@ -1541,7 +1575,15 @@ def assert_close(
if is_non_tensor(input1):
# We skip non-tensor data
continue
assert_allclose_td(input1, input2, rtol=rtol, atol=atol)
assert_close(
input1,
input2,
rtol=rtol,
atol=atol,
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
)
continue
elif not isinstance(input1, torch.Tensor):
continue
Expand All @@ -1560,7 +1602,12 @@ def assert_close(
new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
if input1.is_nested:
torch.testing.assert_close(
input1v, input2v, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
input1v,
input2v,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
msg=new_msg,
)
else:
torch.testing.assert_close(
Expand Down Expand Up @@ -2719,3 +2766,28 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
values,
**kwargs,
)


def _mismatch_keys(keys1, keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
if set(keys1) - set(keys2):
sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."
else:
sub1 = None
if set(keys2) - set(keys1):
sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have."
else:
sub2 = None
main = [r"keys in tensordicts mismatch."]
if sub1 is not None:
main.append(sub1)
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))
Loading