diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index fdf0d6594..9bcf039c3 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -25,6 +25,7 @@ Sequence, TypeVar, Union, + cast, overload, ) from typing_extensions import TypeAlias @@ -833,6 +834,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]: return obj is Undefined +@overload +def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ... +@overload +def _shallow_copy(obj: Any) -> Any: ... +def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any: + if isinstance(obj, SchemaBase): + return obj.copy(deep=False) + elif isinstance(obj, (list, dict)): + return obj.copy() + else: + return obj + + +@overload +def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ... +@overload +def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... +def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: + copy = partial(_deep_copy, by_ref=by_ref) + if isinstance(obj, SchemaBase): + args = (copy(arg) for arg in obj._args) + kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} + with debug_mode(False): + return obj.__class__(*args, **kwds) + elif isinstance(obj, list): + return [copy(v) for v in obj] + elif isinstance(obj, dict): + return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()} + else: + return obj + + class SchemaBase: """ Base class for schema wrappers. @@ -870,7 +903,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None: if DEBUG_MODE and self._class_is_valid_at_instantiation: self.to_dict(validate=True) - def copy( # noqa: C901 + def copy( self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None ) -> Self: """ @@ -887,53 +920,11 @@ def copy( # noqa: C901 A list of keys for which the contents should not be copied, but only stored by reference. """ - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -1240,6 +1231,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict) diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 1c756c2a2..a77afd4a4 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -23,6 +23,7 @@ Sequence, TypeVar, Union, + cast, overload, ) from typing_extensions import TypeAlias @@ -831,6 +832,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]: return obj is Undefined +@overload +def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ... +@overload +def _shallow_copy(obj: Any) -> Any: ... +def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any: + if isinstance(obj, SchemaBase): + return obj.copy(deep=False) + elif isinstance(obj, (list, dict)): + return obj.copy() + else: + return obj + + +@overload +def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ... +@overload +def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ... +def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any: + copy = partial(_deep_copy, by_ref=by_ref) + if isinstance(obj, SchemaBase): + args = (copy(arg) for arg in obj._args) + kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()} + with debug_mode(False): + return obj.__class__(*args, **kwds) + elif isinstance(obj, list): + return [copy(v) for v in obj] + elif isinstance(obj, dict): + return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()} + else: + return obj + + class SchemaBase: """ Base class for schema wrappers. @@ -868,7 +901,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None: if DEBUG_MODE and self._class_is_valid_at_instantiation: self.to_dict(validate=True) - def copy( # noqa: C901 + def copy( self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None ) -> Self: """ @@ -885,53 +918,11 @@ def copy( # noqa: C901 A list of keys for which the contents should not be copied, but only stored by reference. """ - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -1238,6 +1229,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict)