Skip to content

Commit

Permalink
test: add test for pytree_node with init
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 8, 2024
1 parent c53a0b1 commit 77f848f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
4 changes: 2 additions & 2 deletions optree/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma

if not init and pytree_node:
raise TypeError(
'PyTree node field must be included in `__init__()`. '
'Or you can explicitly set `optree.dataclasses.field(init=False, pytree_node=False)`.',
'`pytree_node=True` is not allowed for non-init fields. '
'Please explicitly set `optree.dataclasses.field(init=False, pytree_node=False)`.',
)

return dataclasses.field(**kwargs) # pylint: disable=invalid-field-call
Expand Down
36 changes: 35 additions & 1 deletion tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,41 @@ def test_invalid_parameters():
dataclasses.dataclass(slots=False)


def test_init_args():
def test_field_with_init():
with pytest.raises(
TypeError,
match=re.escape("field() got an unexpected keyword argument 'pytree_node'"),
):
dataclasses.field(pytree_node=True)

f1 = optree.dataclasses.field()
assert f1.metadata['pytree_node'] is True
f2 = optree.dataclasses.field(pytree_node=False)
assert f2.metadata['pytree_node'] is False
f3 = optree.dataclasses.field(pytree_node=True)
assert f3.metadata['pytree_node'] is True
with pytest.raises(
TypeError,
match=re.escape('`pytree_node=True` is not allowed for non-init fields.'),
):
optree.dataclasses.field(init=False)
f4 = optree.dataclasses.field(init=False, metadata={'pytree_node': False})
assert f4.metadata['pytree_node'] is False
with pytest.raises(
TypeError,
match=re.escape('`pytree_node=True` is not allowed for non-init fields.'),
):
optree.dataclasses.field(init=False, metadata={'pytree_node': True})
f5 = optree.dataclasses.field(init=False, pytree_node=False)
assert f5.metadata['pytree_node'] is False
with pytest.raises(
TypeError,
match=re.escape('`pytree_node=True` is not allowed for non-init fields.'),
):
optree.dataclasses.field(init=False, pytree_node=True)


def test_dataclass_with_init():
@optree.dataclasses.dataclass(namespace='some-namespace')
class Foo:
a: int
Expand Down

0 comments on commit 77f848f

Please sign in to comment.