From c9f94389c506a638b11f6326229c3717e2fa7e33 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:47:27 +0100 Subject: [PATCH] refactor: Simplify `SchemaBase.copy` (#3543) * refactor: Fix `C901` complexity in `SchemaBase.copy` The nested functions do not reference `self` and can be split out, much like `.to_dict` -> `_todict`. Saw this as a chance to add annotations as well. I think this helps illustrate that `.copy` is a noop for anything other than `SchemaBase | dict | list` - which I think might need to be addressed in the future. * refactor: Further simplify `SchemaBase.copy` 10 lines shorter and is no longer constrained to `list` on the assert. May be slightly faster on `python<3.11` which do not have zero-cost exceptions https://docs.python.org/3.11/whatsnew/3.11.html#misc * refactor(perf): Merge identical `_shallow_copy` branches Remembered a related `ruff` rule [FURB145](https://docs.astral.sh/ruff/rules/slice-copy/). Wouldn't have made this fix, but likely would have led someone there * refactor(perf): Reduce `_deep_copy` - Initialize an empty set **once** at origin, rather than creating a new list per iteration - Renamed to `by_ref` in the new private function, to better describe the operation. - No change to public API. - Define a partial `copy` to reduce repetition - Use a genexpr for `args`, to avoid unpacking twice --- altair/utils/schemapi.py | 90 +++++++++++++++++++------------------- tools/schemapi/schemapi.py | 90 +++++++++++++++++++------------------- 2 files changed, 88 insertions(+), 92 deletions(-) diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index fdf0d6594..9bcf039c3 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -25,6 +25,7 @@ Sequence, TypeVar, Union, + cast, overload, ) from typing_extensions import TypeAlias @@ -833,6 +834,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]: return obj is Undefined +@overload +def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ... +@overload +def _shallow_copy(obj: Any) -> Any: ... +def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any: + if isinstance(obj, SchemaBase): + return obj.copy(deep=False) + elif isinstance(obj, (list, dict)): + return obj.copy() + else: + return obj + + +@overload +def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ... +@overload +def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... +def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: + copy = partial(_deep_copy, by_ref=by_ref) + if isinstance(obj, SchemaBase): + args = (copy(arg) for arg in obj._args) + kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} + with debug_mode(False): + return obj.__class__(*args, **kwds) + elif isinstance(obj, list): + return [copy(v) for v in obj] + elif isinstance(obj, dict): + return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()} + else: + return obj + + class SchemaBase: """ Base class for schema wrappers. @@ -870,7 +903,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None: if DEBUG_MODE and self._class_is_valid_at_instantiation: self.to_dict(validate=True) - def copy( # noqa: C901 + def copy( self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None ) -> Self: """ @@ -887,53 +920,11 @@ def copy( # noqa: C901 A list of keys for which the contents should not be copied, but only stored by reference. """ - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -1240,6 +1231,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict) diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 1c756c2a2..a77afd4a4 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -23,6 +23,7 @@ Sequence, TypeVar, Union, + cast, overload, ) from typing_extensions import TypeAlias @@ -831,6 +832,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]: return obj is Undefined +@overload +def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ... +@overload +def _shallow_copy(obj: Any) -> Any: ... +def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any: + if isinstance(obj, SchemaBase): + return obj.copy(deep=False) + elif isinstance(obj, (list, dict)): + return obj.copy() + else: + return obj + + +@overload +def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ... +@overload +def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... +def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: + copy = partial(_deep_copy, by_ref=by_ref) + if isinstance(obj, SchemaBase): + args = (copy(arg) for arg in obj._args) + kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} + with debug_mode(False): + return obj.__class__(*args, **kwds) + elif isinstance(obj, list): + return [copy(v) for v in obj] + elif isinstance(obj, dict): + return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()} + else: + return obj + + class SchemaBase: """ Base class for schema wrappers. @@ -868,7 +901,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None: if DEBUG_MODE and self._class_is_valid_at_instantiation: self.to_dict(validate=True) - def copy( # noqa: C901 + def copy( self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None ) -> Self: """ @@ -885,53 +918,11 @@ def copy( # noqa: C901 A list of keys for which the contents should not be copied, but only stored by reference. """ - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -1238,6 +1229,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict)