Skip to content

Commit

Permalink
feat(dataclasses): add dataclasses integration
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Jul 2, 2024
1 parent ddbd0e5 commit c7e2cda
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ setattr
delattr
typecheck
subclassed
dataclass
dataclasses
subpath
accessor
Expand Down
2 changes: 1 addition & 1 deletion optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""

from optree import accessor, functools, integration, typing
from optree import accessor, dataclasses, functools, integration, typing
from optree.accessor import (
AutoEntry,
DataclassEntry,
Expand Down
239 changes: 239 additions & 0 deletions optree/dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PyTree integration with :mod:`dataclasses`."""

from __future__ import annotations

import contextlib
import dataclasses
import sys
import types
from dataclasses import * # noqa: F401,F403,RUF100 # pylint: disable=wildcard-import,unused-wildcard-import
from typing import Any, Callable, TypeVar, overload
from typing_extensions import dataclass_transform # Python 3.11+

from optree.accessor import DataclassEntry
from optree.registry import register_pytree_node


__all__ = dataclasses.__all__.copy()


_FIELDS = '__optree_dataclass_fields__'
_PYTREE_NODE_DEFAULT: bool = True


def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-many-arguments
*,
default: Any = dataclasses.MISSING,
default_factory: Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
hash: bool | None = None, # pylint: disable=redefined-builtin
compare: bool = True,
metadata: dict[str, Any] | None = None,
kw_only: bool = dataclasses.MISSING, # type: ignore[assignment] # Python 3.10+
pytree_node: bool = _PYTREE_NODE_DEFAULT,
) -> dataclasses.Field:
"""Field factory for :func:`dataclass`."""
metadata = metadata or {}
metadata['pytree_node'] = pytree_node

Check warning on line 52 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L51-L52

Added lines #L51 - L52 were not covered by tests

kwargs = {

Check warning on line 54 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L54

Added line #L54 was not covered by tests
'default': default,
'default_factory': default_factory,
'init': init,
'repr': repr,
'hash': hash,
'compare': compare,
'metadata': metadata,
}
if sys.version_info >= (3, 10):
kwargs['kw_only'] = kw_only

Check warning on line 64 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L63-L64

Added lines #L63 - L64 were not covered by tests

return dataclasses.field(**kwargs) # pylint: disable=invalid-field-call

Check warning on line 66 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L66

Added line #L66 was not covered by tests


T = TypeVar('T')
U = TypeVar('U')
TypeT = TypeVar('TypeT', bound=type)


@overload # type: ignore[no-redef]
@dataclass_transform(field_specifiers=(field,))
def dataclass( # pylint: disable=too-many-arguments
cls: None,
*,
namespace: str,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
) -> Callable[[TypeT], TypeT]: ...


@overload
@dataclass_transform(field_specifiers=(field,))
def dataclass( # pylint: disable=too-many-arguments
cls: TypeT,
*,
namespace: str,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
) -> TypeT: ...


@dataclass_transform(field_specifiers=(field,))
def dataclass( # noqa: C901 # pylint: disable=function-redefined,too-many-arguments,too-many-locals
cls: TypeT | None = None,
*,
namespace: str,
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
) -> TypeT | Callable[[TypeT], TypeT]:
"""Dataclass decorator with PyTree integration."""
kwargs = {

Check warning on line 129 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L129

Added line #L129 was not covered by tests
'init': init,
'repr': repr,
'eq': eq,
'order': order,
'unsafe_hash': unsafe_hash,
'frozen': frozen,
}
if sys.version_info >= (3, 10):
kwargs['match_args'] = match_args
kwargs['kw_only'] = kw_only
kwargs['slots'] = slots
if sys.version_info >= (3, 11):
kwargs['weakref_slot'] = weakref_slot

Check warning on line 142 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L137-L142

Added lines #L137 - L142 were not covered by tests

if cls is None:

Check warning on line 144 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L144

Added line #L144 was not covered by tests

def decorator(cls: TypeT) -> TypeT:
return dataclass(cls, namespace=namespace, **kwargs) # type: ignore[call-overload]

Check warning on line 147 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L146-L147

Added lines #L146 - L147 were not covered by tests

return decorator

Check warning on line 149 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L149

Added line #L149 was not covered by tests

if not isinstance(cls, type):
raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.')
if _FIELDS in cls.__dict__:
raise TypeError(

Check warning on line 154 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L151-L154

Added lines #L151 - L154 were not covered by tests
f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.',
)

cls = dataclasses.dataclass(cls, **kwargs) # type: ignore[assignment]

Check warning on line 158 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L158

Added line #L158 was not covered by tests

children_fields = {}
metadata_fields = {}
for f in dataclasses.fields(cls):
if f.metadata.get('pytree_node', _PYTREE_NODE_DEFAULT):
if not f.init:
raise TypeError(f'PyTree node field {f.name!r} must be included in __init__.')
children_fields[f.name] = f
elif f.init:
metadata_fields[f.name] = f

Check warning on line 168 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L160-L168

Added lines #L160 - L168 were not covered by tests

children_fields = types.MappingProxyType(children_fields)
metadata_fields = types.MappingProxyType(metadata_fields)
setattr(cls, _FIELDS, (children_fields, metadata_fields))

Check warning on line 172 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L170-L172

Added lines #L170 - L172 were not covered by tests

def flatten_func(obj: T) -> tuple[tuple[U, ...], tuple[tuple[str, Any], ...], tuple[str, ...]]:
children = tuple(getattr(obj, name) for name in children_fields)
metadata = tuple((name, getattr(obj, name)) for name in metadata_fields)
return children, metadata, tuple(children_fields)

Check warning on line 177 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L174-L177

Added lines #L174 - L177 were not covered by tests

def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[U, ...]) -> T: # type: ignore[type-var]
return cls(*children, **dict(metadata))

Check warning on line 180 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L179-L180

Added lines #L179 - L180 were not covered by tests

return register_pytree_node( # type: ignore[return-value]

Check warning on line 182 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L182

Added line #L182 was not covered by tests
cls,
flatten_func,
unflatten_func, # type: ignore[arg-type]
path_entry_type=DataclassEntry,
namespace=namespace,
)


def make_dataclass( # type: ignore[no-redef] # pylint: disable=function-redefined,too-many-arguments,too-many-locals
cls_name: str,
fields: dict[str, Any], # pylint: disable=redefined-outer-name
*,
namespace: str,
bases: tuple[type, ...] = (),
ns: dict[str, Any] | None = None, # redirect to `namespace` to `dataclasses.make_dataclass()`
init: bool = True,
repr: bool = True, # pylint: disable=redefined-builtin
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
match_args: bool = True, # Python 3.10+
kw_only: bool = False, # Python 3.10+
slots: bool = False, # Python 3.10+
weakref_slot: bool = False, # Python 3.11+
module: str | None = None,
) -> type:
"""Make a dataclass with PyTree integration."""
kwargs = {

Check warning on line 211 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L211

Added line #L211 was not covered by tests
'bases': bases,
'namespace': ns,
'init': init,
'repr': repr,
'eq': eq,
'order': order,
'unsafe_hash': unsafe_hash,
'frozen': frozen,
}
if sys.version_info >= (3, 10):
kwargs['match_args'] = match_args
kwargs['kw_only'] = kw_only
kwargs['slots'] = slots
if sys.version_info >= (3, 11):
kwargs['weakref_slot'] = weakref_slot
if sys.version_info >= (3, 12):
if module is None:
try:

Check warning on line 229 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L221-L229

Added lines #L221 - L229 were not covered by tests
# pylint: disable-next=protected-access
module = sys._getframemodulename(1) or '__main__' # type: ignore[attr-defined]
except AttributeError:
with contextlib.suppress(AttributeError, ValueError):

Check warning on line 233 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L231-L233

Added lines #L231 - L233 were not covered by tests
# pylint: disable-next=protected-access
module = sys._getframe(1).f_globals.get('__name__', '__main__')
kwargs['module'] = module

Check warning on line 236 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L235-L236

Added lines #L235 - L236 were not covered by tests

cls = dataclasses.make_dataclass(cls_name, fields=fields, **kwargs) # type: ignore[arg-type]
return dataclass(cls, namespace=namespace) # type: ignore[call-overload]

Check warning on line 239 in optree/dataclasses.py

View check run for this annotation

Codecov / codecov/patch

optree/dataclasses.py#L238-L239

Added lines #L238 - L239 were not covered by tests

0 comments on commit c7e2cda

Please sign in to comment.