Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow two (same/different) Batch objs to be tested for equality #1098

Merged
merged 12 commits into from
Apr 16, 2024
Merged
35 changes: 33 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]

[tool.poetry.dependencies]
python = "^3.11"
deepdiff = "^7.0.1"
gymnasium = "^0.28.0"
h5py = "^3.9.0"
numba = "^0.57.1"
Expand Down
140 changes: 139 additions & 1 deletion test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import pickle
import sys
from itertools import starmap
from typing import cast
from typing import Any, cast

import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff

from tianshou.data import Batch, to_numpy, to_torch

Expand Down Expand Up @@ -565,6 +566,143 @@ def test_batch_standard_compatibility() -> None:
Batch()[0]


class TestBatchEquality:
    """Checks for ``Batch`` equality and inequality comparisons."""

    @staticmethod
    def test_keys_different() -> None:
        """Batches whose key sets only partly overlap must compare unequal."""
        left = Batch(a=[1, 2], b=[100, 50])
        right = Batch(b=[1, 2], c=[100, 50])
        assert left != right

    @staticmethod
    def test_keys_missing() -> None:
        """Removing a key from an otherwise identical batch breaks equality."""
        complete = Batch(a=[1, 2], b=[2, 3, 4])
        reduced = Batch(a=[1, 2], b=[2, 3, 4])
        reduced.pop("b")
        assert complete != reduced

    @staticmethod
    def test_types_keys_different() -> None:
        """A list value and a nested batch under the same key are unequal."""
        flat = Batch(a=[1, 2, 3], b=[4, 5])
        nested = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
        assert flat != nested

    @staticmethod
    def test_array_types_different() -> None:
        """A numpy array and a ``torch.Tensor`` built from the same list are unequal."""
        with_array = Batch(a=[1, 2, 3], b=np.array([4, 5]))
        with_tensor = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
        assert with_array != with_tensor

    @staticmethod
    def test_nested_values_different() -> None:
        """A single differing element inside a nested batch breaks equality."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_nested_shapes_different() -> None:
        """Nested values of different lengths break equality."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_slice_equal() -> None:
        """Two independently taken identical slices compare equal."""
        source = Batch(a=[1, 2, 3])
        first_slice = source[:2]
        second_slice = source[:2]
        assert first_slice == second_slice

    @staticmethod
    def test_slice_ellipsis_equal() -> None:
        """Two identical ellipsis-based slices compare equal."""
        source = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
        first_slice = source[..., 1:]
        second_slice = source[..., 1:]
        assert first_slice == second_slice

    @staticmethod
    def test_empty_batches() -> None:
        """Two empty batches compare equal."""
        left = Batch()
        right = Batch()
        assert left == right

    @staticmethod
    def test_different_order_keys() -> None:
        """Keyword order at construction does not affect equality."""
        a_first = Batch(a=1, b=2)
        b_first = Batch(b=2, a=1)
        assert a_first == b_first

    @staticmethod
    def test_tuple_and_list_types() -> None:
        """A tuple and a list holding the same numbers yield equal batches."""
        from_tuple = Batch(a=(1, 2))
        from_list = Batch(a=[1, 2])
        assert from_tuple == from_list

    @staticmethod
    def test_subbatch_dict_and_batch_types() -> None:
        """A dict value and an equivalent nested batch yield equal batches."""
        from_dict = Batch(a={"x": 1})
        from_batch = Batch(a=Batch(x=1))
        assert from_dict == from_batch


class TestBatchToDict:
    """Checks for ``Batch.to_dict`` with and without recursion into nested batches."""

    @staticmethod
    def test_to_dict_empty_batch_no_recurse() -> None:
        """An empty batch exports to an empty dict."""
        empty: dict[Any, Any] = {}
        assert Batch().to_dict() == empty

    @staticmethod
    def test_to_dict_with_simple_values_recurse() -> None:
        """Scalar, string and array values survive a recursive export."""
        source = Batch(a=1, b="two", c=np.array([3, 4]))
        diff = DeepDiff(
            source.to_dict(recurse=True),
            {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])},
        )
        assert not diff

    @staticmethod
    def test_to_dict_simple() -> None:
        """The default export of scalar and string values matches a plain dict."""
        source = Batch(a=1, b="two")
        assert source.to_dict() == {"a": np.asanyarray(1), "b": "two"}

    @staticmethod
    def test_to_dict_nested_batch_no_recurse() -> None:
        """Without ``recurse`` a nested batch is exported as the batch itself."""
        inner = Batch(c=3)
        outer = Batch(a=1, b=inner)
        diff = DeepDiff(outer.to_dict(), {"a": np.asanyarray(1), "b": inner})
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_recurse() -> None:
        """With ``recurse=True`` a nested batch is exported as a nested dict."""
        inner = Batch(c=3)
        outer = Batch(a=1, b=inner)
        diff = DeepDiff(
            outer.to_dict(recurse=True),
            {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}},
        )
        assert not diff

    @staticmethod
    def test_to_dict_multiple_nested_batch_recurse() -> None:
        """Recursion reaches batches nested two levels deep."""
        inner = Batch(c=Batch(e=3), d=[100, 200, 300])
        outer = Batch(a=1, b=inner)
        wanted = {
            "a": np.asanyarray(1),
            "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])},
        }
        diff = DeepDiff(outer.to_dict(recurse=True), wanted)
        assert not diff

    @staticmethod
    def test_to_dict_array() -> None:
        """An array value is exported unchanged."""
        source = Batch(a=np.array([1, 2, 3]))
        diff = DeepDiff(source.to_dict(), {"a": np.array([1, 2, 3])})
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_with_array() -> None:
        """An array inside a nested batch is exported under the nested dict."""
        inner = Batch(c=np.array([4, 5]))
        outer = Batch(a=1, b=inner)
        diff = DeepDiff(
            outer.to_dict(recurse=True),
            {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}},
        )
        assert not diff

    @staticmethod
    def test_to_dict_torch_tensor() -> None:
        """An array obtained from a torch tensor is exported unchanged."""
        stored = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        source = Batch(a=stored)
        # Built separately so the comparison is by value, not by identity.
        rebuilt = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        diff = DeepDiff(source.to_dict(), {"a": rebuilt})
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
        """A tensor-derived array inside a nested batch exports recursively."""
        inner = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy())
        outer = Batch(a=1, b=inner)
        wanted = {
            "a": np.asanyarray(1),
            "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()},
        }
        diff = DeepDiff(outer.to_dict(recurse=True), wanted)
        assert not diff


if __name__ == "__main__":
test_batch()
test_batch_over_batch()
Expand Down
23 changes: 19 additions & 4 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import torch
from deepdiff import DeepDiff

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
Expand Down Expand Up @@ -268,6 +269,9 @@ def __repr__(self) -> str:
def __iter__(self) -> Iterator[Self]:
...

def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is a batch holding the same content as this one.

    :param other: the object to compare against; may be of any type.
    :return: ``True`` if the two are considered equal, ``False`` otherwise.
    """
    ...

def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...
Expand Down Expand Up @@ -396,7 +400,7 @@ def split(
"""
...

def to_dict(self, recurse: bool = False) -> dict[str, Any]:
    """Return the content of the batch as a dict keyed by entry name.

    :param recurse: whether nested batches should be converted to dicts as
        well instead of being returned as batch objects.
    :return: a dict mapping each key of the batch to its value.
    """
    ...

def to_list_of_dicts(self) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -433,11 +437,11 @@ def __init__(
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore

def to_dict(self, recurse: bool = False) -> dict[str, Any]:
    """Export the entries stored on this batch as a new dict.

    :param recurse: if ``True``, every value that is itself a ``Batch`` is
        exported with ``to_dict(recurse=True)`` as well, so the result holds
        nested dicts. If ``False`` (the default), nested batches are placed
        in the result as they are.
    :return: a new dict with one item per entry of ``self.__dict__``. Values
        other than nested batches are the stored objects themselves, not
        copies.
    """
    result = {}
    for k, v in self.__dict__.items():
        # Only descend when asked to; the default keeps sub-batches intact.
        if recurse and isinstance(v, Batch):
            v = v.to_dict(recurse=recurse)
        result[k] = v
    return result

Expand Down Expand Up @@ -500,6 +504,17 @@ def __getitem__(self, index: str | IndexType) -> Any:
return new_batch
raise IndexError("Cannot access item from empty Batch object.")

def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is a batch with the same content as this one.

    An object that is not an instance of this batch's class is never equal.
    Otherwise both batches are exported with ``to_dict(recurse=True)``, every
    ``torch.Tensor`` found in the exported dicts is replaced by a numpy array,
    and the two dicts are compared with ``DeepDiff``; an empty diff means the
    batches are equal.

    Neither ``self`` nor ``other`` is modified by the comparison.

    :param other: the object to compare against; may be of any type.
    :return: ``True`` if both batches hold the same content, else ``False``.
    """
    if not isinstance(other, self.__class__):
        return False

    def _tensors_as_numpy(exported: dict[str, Any]) -> dict[str, Any]:
        # Build a new dict in which tensors are swapped for numpy arrays.
        # Nested batches arrive here as nested dicts because the export
        # below is done with recurse=True.
        converted: dict[str, Any] = {}
        for key, value in exported.items():
            if isinstance(value, dict):
                value = _tensors_as_numpy(value)
            elif isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            converted[key] = value
        return converted

    # NOTE(review): this method used to call self.to_numpy() and
    # other.to_numpy(), which changed both operands in place during a
    # supposedly harmless equality check. The tensor conversion is therefore
    # applied to the exported dicts instead of to the batches. It is
    # presumably the same conversion to_numpy() performs - confirm against
    # its implementation.
    # Follow-up agreed in review, for a separate change: to_numpy() and
    # to_torch() still operate in place although the convention is that only
    # methods ending in "_" do. The in-place variant should become
    # to_numpy_() (returning nothing) and to_numpy() should return a new
    # batch; that is a breaking change, so all current callers need checking.
    this_dict = _tensors_as_numpy(self.to_dict(recurse=True))
    other_dict = _tensors_as_numpy(other.to_dict(recurse=True))

    return not DeepDiff(this_dict, other_dict)

def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
Expand Down
Loading