From eb6febd4fc49352b31e1d2f40681ae61b719cfda Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 22 Sep 2024 16:20:56 +0100 Subject: [PATCH] fix(typing): Improve `Then` annotations, autocompletion, docs (#3567) --- altair/utils/__init__.py | 3 +- altair/utils/core.py | 4 +- altair/utils/schemapi.py | 102 ++++++++- altair/vegalite/v5/api.py | 177 ++++++++++----- altair/vegalite/v5/schema/channels.py | 307 +++++++++++++------------- tests/vegalite/v5/test_api.py | 77 +++++-- tools/generate_api_docs.py | 17 +- tools/generate_schema_wrapper.py | 34 ++- tools/schemapi/schemapi.py | 102 ++++++++- 9 files changed, 570 insertions(+), 253 deletions(-) diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index b6855e1ee..bfcb52eff 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -12,7 +12,7 @@ from .deprecation import AltairDeprecationWarning, deprecated, deprecated_warn from .html import spec_to_html from .plugin_registry import PluginRegistry -from .schemapi import Optional, SchemaBase, Undefined, is_undefined +from .schemapi import Optional, SchemaBase, SchemaLike, Undefined, is_undefined __all__ = ( "SHORTHAND_KEYS", @@ -20,6 +20,7 @@ "Optional", "PluginRegistry", "SchemaBase", + "SchemaLike", "Undefined", "deprecated", "deprecated_warn", diff --git a/altair/utils/core.py b/altair/utils/core.py index 7e8340324..734efe1f2 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -28,7 +28,7 @@ from narwhals.dependencies import get_polars, is_pandas_dataframe from narwhals.typing import IntoDataFrame -from altair.utils.schemapi import SchemaBase, Undefined +from altair.utils.schemapi import SchemaBase, SchemaLike, Undefined if sys.version_info >= (3, 12): from typing import Protocol, TypeAliasType, runtime_checkable @@ -869,6 +869,8 @@ def _wrap_in_channel(self, obj: Any, encoding: str, /): obj = {"shorthand": obj} elif isinstance(obj, (list, tuple)): return [self._wrap_in_channel(el, encoding) for el in obj] + elif isinstance(obj, SchemaLike): + obj = obj.to_dict() if channel := self.name_to_channel.get(encoding): tp = channel["value" if "value" in obj else "field"] try: diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index 84f5be277..ffe28ffa2 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -18,10 +18,12 @@ Any, Dict, Final, + Generic, Iterable, Iterator, List, Literal, + Mapping, Sequence, TypeVar, Union, @@ -41,6 +43,11 @@ # not yet be fully instantiated in case your code is being executed during import time from altair import vegalite +if sys.version_info >= (3, 12): + from typing import Protocol, TypeAliasType, runtime_checkable +else: + from typing_extensions import Protocol, TypeAliasType, runtime_checkable + if TYPE_CHECKING: from types import ModuleType from typing import ClassVar @@ -524,11 +531,7 @@ def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) for k, v in obj.items() if v is not Undefined } - elif ( - hasattr(obj, "to_dict") - and (module_name := obj.__module__) - and module_name.startswith("altair") - ): + elif isinstance(obj, SchemaLike): return obj.to_dict() elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): return pd_opt.Timestamp(obj).isoformat() @@ -789,6 +792,95 @@ def _get_default_error_message( return message +_JSON_VT_co = TypeVar( + "_JSON_VT_co", + Literal["string"], + Literal["object"], + Literal["array"], + covariant=True, +) +""" +One of a subset of JSON Schema `primitive types`_: + + ["string", "object", "array"] + +.. _primitive types: + https://json-schema.org/draft-07/json-schema-validation#rfc.section.6.1.1 +""" + +_TypeMap = TypeAliasType( + "_TypeMap", Mapping[Literal["type"], _JSON_VT_co], type_params=(_JSON_VT_co,) +) +""" +A single item JSON Schema using the `type`_ keyword. + +This may represent **one of**: + + {"type": "string"} + {"type": "object"} + {"type": "array"} + +.. _type: + https://json-schema.org/understanding-json-schema/reference/type +""" + +# NOTE: Type checkers want opposing things: +# - `mypy` : Covariant type variable "_JSON_VT_co" used in protocol where invariant one is expected [misc] +# - `pyright`: Type variable "_JSON_VT_co" used in generic protocol "SchemaLike" should be covariant [reportInvalidTypeVarUse] +# Siding with `pyright` as this is consistent with https://github.com/python/typeshed/blob/9e506eb5e8fc2823db8c60ad561b1145ff114947/stdlib/typing.pyi#L690 + + +@runtime_checkable +class SchemaLike(Generic[_JSON_VT_co], Protocol): # type: ignore[misc] + """ + Represents ``altair`` classes which *may* not derive ``SchemaBase``. + + Attributes + ---------- + _schema + A single item JSON Schema using the `type`_ keyword. + + Notes + ----- + Should be kept tightly defined to the **minimum** requirements for: + - Converting into a form that can be validated by `jsonschema`_. + - Avoiding calling ``.to_dict()`` on a class external to ``altair``. + - ``_schema`` is more accurately described as a ``ClassVar`` + - See `discussion`_ for blocking issue. + + .. _jsonschema: + https://github.com/python-jsonschema/jsonschema + .. _type: + https://json-schema.org/understanding-json-schema/reference/type + .. _discussion: + https://github.com/python/typing/discussions/1424 + """ + + _schema: _TypeMap[_JSON_VT_co] + + def to_dict(self, *args, **kwds) -> Any: ... + + +@runtime_checkable +class ConditionLike(SchemaLike[Literal["object"]], Protocol): + """ + Represents the wrapped state of a conditional encoding or property. + + Attributes + ---------- + condition + One or more (predicate, statement) pairs which each form a condition. + + Notes + ----- + - Can be extended with additional conditions. + - *Does not* define a default value, but can be finalized with one. + """ + + condition: Any + _schema: _TypeMap[Literal["object"]] = {"type": "object"} + + class UndefinedType: """A singleton object for marking undefined parameters.""" diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index b437c1072..fce8c080d 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -10,17 +10,7 @@ import typing as t import warnings from copy import deepcopy as _deepcopy -from typing import ( - TYPE_CHECKING, - Any, - Literal, - Protocol, - Sequence, - TypeVar, - Union, - overload, -) -from typing_extensions import TypeAlias +from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, Union, overload import jsonschema import narwhals.stable.v1 as nw @@ -34,6 +24,7 @@ from altair.utils._vegafusion_data import using_vegafusion as _using_vegafusion from altair.utils.data import DataType from altair.utils.data import is_data_type as _is_data_type +from altair.utils.schemapi import ConditionLike, _TypeMap from .compiler import vegalite_compilers from .data import data_transformers @@ -42,10 +33,18 @@ from .schema._typing import Map from .theme import themes -if sys.version_info >= (3, 13): +if sys.version_info >= (3, 14): from typing import TypedDict else: from typing_extensions import TypedDict +if sys.version_info >= (3, 12): + from typing import Protocol, runtime_checkable +else: + from typing_extensions import Protocol, runtime_checkable # noqa: F401 +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias if TYPE_CHECKING: from pathlib import Path @@ -356,6 +355,7 @@ def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: class Parameter(_expr_core.OperatorMixin): """A Parameter object.""" + _schema: t.ClassVar[_TypeMap[Literal["object"]]] = {"type": "object"} _counter: int = 0 @classmethod @@ -458,6 +458,8 @@ def __or__(self, other: SchemaBase) -> SelectionPredicateComposition: class ParameterExpression(_expr_core.OperatorMixin): + _schema: t.ClassVar[_TypeMap[Literal["object"]]] = {"type": "object"} + def __init__(self, expr: IntoExpression) -> None: self.expr = expr @@ -472,6 +474,8 @@ def _from_expr(self, expr: IntoExpression) -> ParameterExpression: class SelectionExpression(_expr_core.OperatorMixin): + _schema: t.ClassVar[_TypeMap[Literal["object"]]] = {"type": "object"} + def __init__(self, expr: IntoExpression) -> None: self.expr = expr @@ -509,7 +513,7 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: _PredicateType: TypeAlias = Union[ Parameter, core.Expr, - Map, + "_ConditionExtra", _TestPredicateType, _expr_core.OperatorMixin, ] @@ -534,12 +538,6 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: ``` """ -_ConditionType: TypeAlias = t.Dict[str, Union[_TestPredicateType, Any]] -"""Intermediate type representing a converted `_PredicateType`. - -Prior to parsing any `_StatementType`. -""" - _LiteralValue: TypeAlias = Union[str, bool, float, int] """Primitive python value types.""" @@ -556,15 +554,15 @@ def _is_test_predicate(obj: Any) -> TypeIs[_TestPredicateType]: return isinstance(obj, (str, _expr_core.Expression, core.PredicateComposition)) -def _get_predicate_expr(p: Parameter) -> Optional[str | SchemaBase]: +def _get_predicate_expr(p: Parameter) -> Optional[_TestPredicateType]: # https://vega.github.io/vega-lite/docs/predicate.html return getattr(p.param, "expr", Undefined) def _predicate_to_condition( predicate: _PredicateType, *, empty: Optional[bool] = Undefined -) -> _ConditionType: - condition: _ConditionType +) -> _Condition: + condition: _Condition if isinstance(predicate, Parameter): predicate_expr = _get_predicate_expr(predicate) if predicate.param_type == "selection" or utils.is_undefined(predicate_expr): @@ -591,12 +589,12 @@ def _predicate_to_condition( def _condition_to_selection( - condition: _ConditionType, + condition: _Condition, if_true: _StatementType, if_false: _StatementType, **kwargs: Any, -) -> SchemaBase | dict[str, _ConditionType | Any]: - selection: SchemaBase | dict[str, _ConditionType | Any] +) -> SchemaBase | _Conditional[_Condition]: + selection: SchemaBase | _Conditional[_Condition] if isinstance(if_true, SchemaBase): if_true = if_true.to_dict() elif isinstance(if_true, str): @@ -610,56 +608,98 @@ def _condition_to_selection( else: if_true = utils.parse_shorthand(if_true) if_true.update(kwargs) - condition.update(if_true) + cond_mutable: Any = dict(condition) + cond_mutable.update(if_true) if isinstance(if_false, SchemaBase): # For the selection, the channel definitions all allow selections # already. So use this SchemaBase wrapper if possible. selection = if_false.copy() - selection.condition = condition + selection.condition = cond_mutable elif isinstance(if_false, (str, dict)): if isinstance(if_false, str): if_false = utils.parse_shorthand(if_false) if_false.update(kwargs) - selection = dict(condition=condition, **if_false) + selection = _Conditional(condition=cond_mutable, **if_false) # type: ignore[typeddict-item] else: raise TypeError(if_false) return selection -class _ConditionClosed(TypedDict, closed=True, total=False): # type: ignore[call-arg] +class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call-arg] # https://peps.python.org/pep-0728/ - # Parameter {"param", "value", "empty"} - # Predicate {"test", "value"} + # Likely a Field predicate empty: Optional[bool] param: Parameter | str test: _TestPredicateType value: Any + __extra_items__: _StatementType | OneOrSeq[_LiteralValue] -class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call-arg] +_Condition: TypeAlias = _ConditionExtra +""" +A singular, *possibly* non-chainable condition produced by ``.when()``. + +The default **permissive** representation. + +Allows arbitrary additional keys that *may* be present in a `Conditional Field`_ +but not a `Conditional Value`_. + +.. _Conditional Field: + https://vega.github.io/vega-lite/docs/condition.html#field +.. _Conditional Value: + https://vega.github.io/vega-lite/docs/condition.html#value +""" + + +class _ConditionClosed(TypedDict, closed=True, total=False): # type: ignore[call-arg] # https://peps.python.org/pep-0728/ - # Likely a Field predicate + # Parameter {"param", "value", "empty"} + # Predicate {"test", "value"} empty: Optional[bool] param: Parameter | str test: _TestPredicateType value: Any - __extra_items__: _StatementType | OneOrSeq[_LiteralValue] -_Condition: TypeAlias = _ConditionExtra -"""A singular, non-chainable condition produced by ``.when()``.""" - _Conditions: TypeAlias = t.List[_ConditionClosed] -"""Chainable conditions produced by ``.when()`` and ``Then.when()``.""" +""" +Chainable conditions produced by ``.when()`` and ``Then.when()``. + +All must be a `Conditional Value`_. + +.. _Conditional Value: + https://vega.github.io/vega-lite/docs/condition.html#value +""" _C = TypeVar("_C", _Conditions, _Condition) class _Conditional(TypedDict, t.Generic[_C], total=False): + """ + A dictionary representation of a conditional encoding or property. + + Parameters + ---------- + condition + One or more (predicate, statement) pairs which each form a condition. + value + An optional default value, used when no predicates were met. + """ + condition: Required[_C] value: Any +IntoCondition: TypeAlias = Union[ConditionLike, _Conditional[Any]] +""" +Anything that can be converted into a conditional encoding or property. + +Notes +----- +Represents all outputs from `when-then-otherwise` conditions, which are not ``SchemaBase`` types. +""" + + class _Value(TypedDict, closed=True, total=False): # type: ignore[call-arg] # https://peps.python.org/pep-0728/ value: Required[Any] @@ -687,6 +727,11 @@ def _is_condition_extra(obj: Any, *objs: Any, kwds: Map) -> TypeIs[_Condition]: return isinstance(obj, str) or any(_is_extra(obj, *objs, kwds=kwds)) +def _is_condition_closed(obj: Map) -> TypeIs[_ConditionClosed]: + """Return `True` if ``obj`` can be used in a chained condition.""" + return {"empty", "param", "test", "value"} >= obj.keys() + + def _parse_when_constraints( constraints: dict[str, _FieldEqualType], / ) -> Iterator[BinaryExpression]: @@ -754,7 +799,7 @@ def _parse_when( *more_predicates: _ComposablePredicateType, empty: Optional[bool], **constraints: _FieldEqualType, -) -> _ConditionType: +) -> _Condition: composed: _PredicateType if utils.is_undefined(predicate): if more_predicates or constraints: @@ -811,7 +856,7 @@ def _parse_otherwise( class _BaseWhen(Protocol): # NOTE: Temporary solution to non-SchemaBase copy - _condition: _ConditionType + _condition: _Condition def _when_then( self, statement: _StatementType, kwds: dict[str, Any], / @@ -835,7 +880,7 @@ class When(_BaseWhen): `polars.when `__ """ - def __init__(self, condition: _ConditionType, /) -> None: + def __init__(self, condition: _Condition, /) -> None: self._condition = condition def __repr__(self) -> str: @@ -890,7 +935,7 @@ def then(self, statement: _StatementType, /, **kwds: Any) -> Then[Any]: return Then(_Conditional(condition=[condition])) -class Then(SchemaBase, t.Generic[_C]): +class Then(ConditionLike, t.Generic[_C]): """ Utility class for ``when-then-otherwise`` conditions. @@ -906,11 +951,8 @@ class Then(SchemaBase, t.Generic[_C]): `polars.when `__ """ - _schema = {"type": "object"} - def __init__(self, conditions: _Conditional[_C], /) -> None: - super().__init__(**conditions) - self.condition: _C + self.condition: _C = conditions["condition"] @overload def otherwise(self, statement: _TSchemaBase, /, **kwds: Any) -> _TSchemaBase: ... @@ -1068,12 +1110,22 @@ def when( ) raise NotImplementedError(msg) - def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: # type: ignore[override] - m = super().to_dict(*args, **kwds) - return _Conditional(condition=m["condition"]) + def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: + return _Conditional(condition=self.condition.copy()) def __deepcopy__(self, memo: Any) -> Self: - return type(self)(_Conditional(condition=_deepcopy(self.condition))) + return type(self)(_Conditional(condition=_deepcopy(self.condition, memo))) + + def __repr__(self) -> str: + name = type(self).__name__ + COND = "condition: " + LB, RB = "{", "}" + if len(self.condition) == 1: + args = f"{COND}{self.condition!r}".replace("\n", "\n ") + else: + conds = "\n ".join(f"{c!r}" for c in self.condition) + args = f"{COND}[\n " f"{conds}\n ]" + return f"{name}({LB}\n {args}\n{RB})" class ChainedWhen(_BaseWhen): @@ -1091,7 +1143,7 @@ class ChainedWhen(_BaseWhen): def __init__( self, - condition: _ConditionType, + condition: _Condition, conditions: _Conditional[_Conditions], /, ) -> None: @@ -1144,9 +1196,18 @@ def then(self, statement: _StatementType, /, **kwds: Any) -> Then[_Conditions]: ) """ condition = self._when_then(statement, kwds) - conditions = self._conditions.copy() - conditions["condition"].append(condition) - return Then(conditions) + if _is_condition_closed(condition): + conditions = self._conditions.copy() + conditions["condition"].append(condition) + return Then(conditions) + else: + cond = _reveal_parsed_shorthand(condition) + msg = ( + f"Chained conditions cannot be mixed with field conditions.\n" + f"Shorthand {statement!r} expanded to {cond!r}\n\n" + f"Use `alt.value({statement!r})` if this is not a shorthand string." + ) + raise TypeError(msg) def when( @@ -1663,7 +1724,7 @@ def condition( *, empty: Optional[bool] = ..., **kwargs: Any, -) -> dict[str, _ConditionType | Any]: ... +) -> _Conditional[_Condition]: ... @overload def condition( predicate: _PredicateType, @@ -1672,7 +1733,7 @@ def condition( *, empty: Optional[bool] = ..., **kwargs: Any, -) -> dict[str, _ConditionType | Any]: ... +) -> _Conditional[_Condition]: ... @overload def condition( predicate: _PredicateType, if_true: str, if_false: str, **kwargs: Any @@ -1685,7 +1746,7 @@ def condition( *, empty: Optional[bool] = Undefined, **kwargs: Any, -) -> SchemaBase | dict[str, _ConditionType | Any]: +) -> SchemaBase | _Conditional[_Condition]: """ A conditional attribute or encoding. @@ -4907,7 +4968,7 @@ def remove_prop(subchart: ChartType, prop: str) -> ChartType: # or it must be Undefined or identical to proceed. output_dict[prop] = chart[prop] else: - msg = f"There are inconsistent values {values} for {prop}" + msg = f"There are inconsistent values {values} for {prop}" # pyright: ignore[reportPossiblyUnboundVariable] raise ValueError(msg) subcharts = [remove_prop(c, prop) for c in subcharts] diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 116f68d1c..ac645fd85 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -29,6 +29,7 @@ from altair import Parameter, SchemaBase from altair.typing import Optional + from altair.vegalite.v5.api import IntoCondition __all__ = [ @@ -24914,122 +24915,132 @@ def __init__(self, value, **kwds): super().__init__(value=value, **kwds) -ChannelAngle: TypeAlias = Union[str, Angle, Map, AngleDatum, AngleValue] -ChannelColor: TypeAlias = Union[str, Color, Map, ColorDatum, ColorValue] -ChannelColumn: TypeAlias = Union[str, Column, Map] -ChannelDescription: TypeAlias = Union[str, Description, Map, DescriptionValue] -ChannelDetail: TypeAlias = OneOrSeq[Union[str, Detail, Map]] -ChannelFacet: TypeAlias = Union[str, Facet, Map] -ChannelFill: TypeAlias = Union[str, Fill, Map, FillDatum, FillValue] -ChannelFillOpacity: TypeAlias = Union[ - str, FillOpacity, Map, FillOpacityDatum, FillOpacityValue +AnyAngle: TypeAlias = Union[Angle, AngleDatum, AngleValue] +AnyColor: TypeAlias = Union[Color, ColorDatum, ColorValue] +AnyDescription: TypeAlias = Union[Description, DescriptionValue] +AnyFill: TypeAlias = Union[Fill, FillDatum, FillValue] +AnyFillOpacity: TypeAlias = Union[FillOpacity, FillOpacityDatum, FillOpacityValue] +AnyHref: TypeAlias = Union[Href, HrefValue] +AnyLatitude: TypeAlias = Union[Latitude, LatitudeDatum] +AnyLatitude2: TypeAlias = Union[Latitude2, Latitude2Datum, Latitude2Value] +AnyLongitude: TypeAlias = Union[Longitude, LongitudeDatum] +AnyLongitude2: TypeAlias = Union[Longitude2, Longitude2Datum, Longitude2Value] +AnyOpacity: TypeAlias = Union[Opacity, OpacityDatum, OpacityValue] +AnyOrder: TypeAlias = Union[Order, OrderValue] +AnyRadius: TypeAlias = Union[Radius, RadiusDatum, RadiusValue] +AnyRadius2: TypeAlias = Union[Radius2, Radius2Datum, Radius2Value] +AnyShape: TypeAlias = Union[Shape, ShapeDatum, ShapeValue] +AnySize: TypeAlias = Union[Size, SizeDatum, SizeValue] +AnyStroke: TypeAlias = Union[Stroke, StrokeDatum, StrokeValue] +AnyStrokeDash: TypeAlias = Union[StrokeDash, StrokeDashDatum, StrokeDashValue] +AnyStrokeOpacity: TypeAlias = Union[ + StrokeOpacity, StrokeOpacityDatum, StrokeOpacityValue ] -ChannelHref: TypeAlias = Union[str, Href, Map, HrefValue] -ChannelKey: TypeAlias = Union[str, Key, Map] -ChannelLatitude: TypeAlias = Union[str, Latitude, Map, LatitudeDatum] -ChannelLatitude2: TypeAlias = Union[str, Latitude2, Map, Latitude2Datum, Latitude2Value] -ChannelLongitude: TypeAlias = Union[str, Longitude, Map, LongitudeDatum] -ChannelLongitude2: TypeAlias = Union[ - str, Longitude2, Map, Longitude2Datum, Longitude2Value -] -ChannelOpacity: TypeAlias = Union[str, Opacity, Map, OpacityDatum, OpacityValue] -ChannelOrder: TypeAlias = OneOrSeq[Union[str, Order, Map, OrderValue]] -ChannelRadius: TypeAlias = Union[str, Radius, Map, RadiusDatum, RadiusValue] -ChannelRadius2: TypeAlias = Union[str, Radius2, Map, Radius2Datum, Radius2Value] -ChannelRow: TypeAlias = Union[str, Row, Map] -ChannelShape: TypeAlias = Union[str, Shape, Map, ShapeDatum, ShapeValue] -ChannelSize: TypeAlias = Union[str, Size, Map, SizeDatum, SizeValue] -ChannelStroke: TypeAlias = Union[str, Stroke, Map, StrokeDatum, StrokeValue] -ChannelStrokeDash: TypeAlias = Union[ - str, StrokeDash, Map, StrokeDashDatum, StrokeDashValue -] -ChannelStrokeOpacity: TypeAlias = Union[ - str, StrokeOpacity, Map, StrokeOpacityDatum, StrokeOpacityValue -] -ChannelStrokeWidth: TypeAlias = Union[ - str, StrokeWidth, Map, StrokeWidthDatum, StrokeWidthValue -] -ChannelText: TypeAlias = Union[str, Text, Map, TextDatum, TextValue] -ChannelTheta: TypeAlias = Union[str, Theta, Map, ThetaDatum, ThetaValue] -ChannelTheta2: TypeAlias = Union[str, Theta2, Map, Theta2Datum, Theta2Value] -ChannelTooltip: TypeAlias = OneOrSeq[Union[str, Tooltip, Map, TooltipValue]] -ChannelUrl: TypeAlias = Union[str, Url, Map, UrlValue] -ChannelX: TypeAlias = Union[str, X, Map, XDatum, XValue] -ChannelX2: TypeAlias = Union[str, X2, Map, X2Datum, X2Value] -ChannelXError: TypeAlias = Union[str, XError, Map, XErrorValue] -ChannelXError2: TypeAlias = Union[str, XError2, Map, XError2Value] -ChannelXOffset: TypeAlias = Union[str, XOffset, Map, XOffsetDatum, XOffsetValue] -ChannelY: TypeAlias = Union[str, Y, Map, YDatum, YValue] -ChannelY2: TypeAlias = Union[str, Y2, Map, Y2Datum, Y2Value] -ChannelYError: TypeAlias = Union[str, YError, Map, YErrorValue] -ChannelYError2: TypeAlias = Union[str, YError2, Map, YError2Value] -ChannelYOffset: TypeAlias = Union[str, YOffset, Map, YOffsetDatum, YOffsetValue] +AnyStrokeWidth: TypeAlias = Union[StrokeWidth, StrokeWidthDatum, StrokeWidthValue] +AnyText: TypeAlias = Union[Text, TextDatum, TextValue] +AnyTheta: TypeAlias = Union[Theta, ThetaDatum, ThetaValue] +AnyTheta2: TypeAlias = Union[Theta2, Theta2Datum, Theta2Value] +AnyTooltip: TypeAlias = Union[Tooltip, TooltipValue] +AnyUrl: TypeAlias = Union[Url, UrlValue] +AnyX: TypeAlias = Union[X, XDatum, XValue] +AnyX2: TypeAlias = Union[X2, X2Datum, X2Value] +AnyXError: TypeAlias = Union[XError, XErrorValue] +AnyXError2: TypeAlias = Union[XError2, XError2Value] +AnyXOffset: TypeAlias = Union[XOffset, XOffsetDatum, XOffsetValue] +AnyY: TypeAlias = Union[Y, YDatum, YValue] +AnyY2: TypeAlias = Union[Y2, Y2Datum, Y2Value] +AnyYError: TypeAlias = Union[YError, YErrorValue] +AnyYError2: TypeAlias = Union[YError2, YError2Value] +AnyYOffset: TypeAlias = Union[YOffset, YOffsetDatum, YOffsetValue] + +ChannelAngle: TypeAlias = Union[str, AnyAngle, "IntoCondition", Map] +ChannelColor: TypeAlias = Union[str, AnyColor, "IntoCondition", Map] +ChannelColumn: TypeAlias = Union[str, Column, "IntoCondition", Map] +ChannelDescription: TypeAlias = Union[str, AnyDescription, "IntoCondition", Map] +ChannelDetail: TypeAlias = OneOrSeq[Union[str, Detail, "IntoCondition", Map]] +ChannelFacet: TypeAlias = Union[str, Facet, "IntoCondition", Map] +ChannelFill: TypeAlias = Union[str, AnyFill, "IntoCondition", Map] +ChannelFillOpacity: TypeAlias = Union[str, AnyFillOpacity, "IntoCondition", Map] +ChannelHref: TypeAlias = Union[str, AnyHref, "IntoCondition", Map] +ChannelKey: TypeAlias = Union[str, Key, "IntoCondition", Map] +ChannelLatitude: TypeAlias = Union[str, AnyLatitude, "IntoCondition", Map] +ChannelLatitude2: TypeAlias = Union[str, AnyLatitude2, "IntoCondition", Map] +ChannelLongitude: TypeAlias = Union[str, AnyLongitude, "IntoCondition", Map] +ChannelLongitude2: TypeAlias = Union[str, AnyLongitude2, "IntoCondition", Map] +ChannelOpacity: TypeAlias = Union[str, AnyOpacity, "IntoCondition", Map] +ChannelOrder: TypeAlias = OneOrSeq[Union[str, AnyOrder, "IntoCondition", Map]] +ChannelRadius: TypeAlias = Union[str, AnyRadius, "IntoCondition", Map] +ChannelRadius2: TypeAlias = Union[str, AnyRadius2, "IntoCondition", Map] +ChannelRow: TypeAlias = Union[str, Row, "IntoCondition", Map] +ChannelShape: TypeAlias = Union[str, AnyShape, "IntoCondition", Map] +ChannelSize: TypeAlias = Union[str, AnySize, "IntoCondition", Map] +ChannelStroke: TypeAlias = Union[str, AnyStroke, "IntoCondition", Map] +ChannelStrokeDash: TypeAlias = Union[str, AnyStrokeDash, "IntoCondition", Map] +ChannelStrokeOpacity: TypeAlias = Union[str, AnyStrokeOpacity, "IntoCondition", Map] +ChannelStrokeWidth: TypeAlias = Union[str, AnyStrokeWidth, "IntoCondition", Map] +ChannelText: TypeAlias = Union[str, AnyText, "IntoCondition", Map] +ChannelTheta: TypeAlias = Union[str, AnyTheta, "IntoCondition", Map] +ChannelTheta2: TypeAlias = Union[str, AnyTheta2, "IntoCondition", Map] +ChannelTooltip: TypeAlias = OneOrSeq[Union[str, AnyTooltip, "IntoCondition", Map]] +ChannelUrl: TypeAlias = Union[str, AnyUrl, "IntoCondition", Map] +ChannelX: TypeAlias = Union[str, AnyX, "IntoCondition", Map] +ChannelX2: TypeAlias = Union[str, AnyX2, "IntoCondition", Map] +ChannelXError: TypeAlias = Union[str, AnyXError, "IntoCondition", Map] +ChannelXError2: TypeAlias = Union[str, AnyXError2, "IntoCondition", Map] +ChannelXOffset: TypeAlias = Union[str, AnyXOffset, "IntoCondition", Map] +ChannelY: TypeAlias = Union[str, AnyY, "IntoCondition", Map] +ChannelY2: TypeAlias = Union[str, AnyY2, "IntoCondition", Map] +ChannelYError: TypeAlias = Union[str, AnyYError, "IntoCondition", Map] +ChannelYError2: TypeAlias = Union[str, AnyYError2, "IntoCondition", Map] +ChannelYOffset: TypeAlias = Union[str, AnyYOffset, "IntoCondition", Map] class _EncodingMixin: def encode( self, *args: Any, - angle: Optional[str | Angle | Map | AngleDatum | AngleValue] = Undefined, - color: Optional[str | Color | Map | ColorDatum | ColorValue] = Undefined, - column: Optional[str | Column | Map] = Undefined, - description: Optional[str | Description | Map | DescriptionValue] = Undefined, - detail: Optional[OneOrSeq[str | Detail | Map]] = Undefined, - facet: Optional[str | Facet | Map] = Undefined, - fill: Optional[str | Fill | Map | FillDatum | FillValue] = Undefined, - fillOpacity: Optional[ - str | FillOpacity | Map | FillOpacityDatum | FillOpacityValue - ] = Undefined, - href: Optional[str | Href | Map | HrefValue] = Undefined, - key: Optional[str | Key | Map] = Undefined, - latitude: Optional[str | Latitude | Map | LatitudeDatum] = Undefined, - latitude2: Optional[ - str | Latitude2 | Map | Latitude2Datum | Latitude2Value - ] = Undefined, - longitude: Optional[str | Longitude | Map | LongitudeDatum] = Undefined, - longitude2: Optional[ - str | Longitude2 | Map | Longitude2Datum | Longitude2Value - ] = Undefined, - opacity: Optional[ - str | Opacity | Map | OpacityDatum | OpacityValue - ] = Undefined, - order: Optional[OneOrSeq[str | Order | Map | OrderValue]] = Undefined, - radius: Optional[str | Radius | Map | RadiusDatum | RadiusValue] = Undefined, - radius2: Optional[ - str | Radius2 | Map | Radius2Datum | Radius2Value - ] = Undefined, - row: Optional[str | Row | Map] = Undefined, - shape: Optional[str | Shape | Map | ShapeDatum | ShapeValue] = Undefined, - size: Optional[str | Size | Map | SizeDatum | SizeValue] = Undefined, - stroke: Optional[str | Stroke | Map | StrokeDatum | StrokeValue] = Undefined, - strokeDash: Optional[ - str | StrokeDash | Map | StrokeDashDatum | StrokeDashValue - ] = Undefined, + angle: Optional[str | AnyAngle | IntoCondition | Map] = Undefined, + color: Optional[str | AnyColor | IntoCondition | Map] = Undefined, + column: Optional[str | Column | IntoCondition | Map] = Undefined, + description: Optional[str | AnyDescription | IntoCondition | Map] = Undefined, + detail: Optional[OneOrSeq[str | Detail | IntoCondition | Map]] = Undefined, + facet: Optional[str | Facet | IntoCondition | Map] = Undefined, + fill: Optional[str | AnyFill | IntoCondition | Map] = Undefined, + fillOpacity: Optional[str | AnyFillOpacity | IntoCondition | Map] = Undefined, + href: Optional[str | AnyHref | IntoCondition | Map] = Undefined, + key: Optional[str | Key | IntoCondition | Map] = Undefined, + latitude: Optional[str | AnyLatitude | IntoCondition | Map] = Undefined, + latitude2: Optional[str | AnyLatitude2 | IntoCondition | Map] = Undefined, + longitude: Optional[str | AnyLongitude | IntoCondition | Map] = Undefined, + longitude2: Optional[str | AnyLongitude2 | IntoCondition | Map] = Undefined, + opacity: Optional[str | AnyOpacity | IntoCondition | Map] = Undefined, + order: Optional[OneOrSeq[str | AnyOrder | IntoCondition | Map]] = Undefined, + radius: Optional[str | AnyRadius | IntoCondition | Map] = Undefined, + radius2: Optional[str | AnyRadius2 | IntoCondition | Map] = Undefined, + row: Optional[str | Row | IntoCondition | Map] = Undefined, + shape: Optional[str | AnyShape | IntoCondition | Map] = Undefined, + size: Optional[str | AnySize | IntoCondition | Map] = Undefined, + stroke: Optional[str | AnyStroke | IntoCondition | Map] = Undefined, + strokeDash: Optional[str | AnyStrokeDash | IntoCondition | Map] = Undefined, strokeOpacity: Optional[ - str | StrokeOpacity | Map | StrokeOpacityDatum | StrokeOpacityValue - ] = Undefined, - strokeWidth: Optional[ - str | StrokeWidth | Map | StrokeWidthDatum | StrokeWidthValue - ] = Undefined, - text: Optional[str | Text | Map | TextDatum | TextValue] = Undefined, - theta: Optional[str | Theta | Map | ThetaDatum | ThetaValue] = Undefined, - theta2: Optional[str | Theta2 | Map | Theta2Datum | Theta2Value] = Undefined, - tooltip: Optional[OneOrSeq[str | Tooltip | Map | TooltipValue]] = Undefined, - url: Optional[str | Url | Map | UrlValue] = Undefined, - x: Optional[str | X | Map | XDatum | XValue] = Undefined, - x2: Optional[str | X2 | Map | X2Datum | X2Value] = Undefined, - xError: Optional[str | XError | Map | XErrorValue] = Undefined, - xError2: Optional[str | XError2 | Map | XError2Value] = Undefined, - xOffset: Optional[ - str | XOffset | Map | XOffsetDatum | XOffsetValue - ] = Undefined, - y: Optional[str | Y | Map | YDatum | YValue] = Undefined, - y2: Optional[str | Y2 | Map | Y2Datum | Y2Value] = Undefined, - yError: Optional[str | YError | Map | YErrorValue] = Undefined, - yError2: Optional[str | YError2 | Map | YError2Value] = Undefined, - yOffset: Optional[ - str | YOffset | Map | YOffsetDatum | YOffsetValue - ] = Undefined, + str | AnyStrokeOpacity | IntoCondition | Map + ] = Undefined, + strokeWidth: Optional[str | AnyStrokeWidth | IntoCondition | Map] = Undefined, + text: Optional[str | AnyText | IntoCondition | Map] = Undefined, + theta: Optional[str | AnyTheta | IntoCondition | Map] = Undefined, + theta2: Optional[str | AnyTheta2 | IntoCondition | Map] = Undefined, + tooltip: Optional[OneOrSeq[str | AnyTooltip | IntoCondition | Map]] = Undefined, + url: Optional[str | AnyUrl | IntoCondition | Map] = Undefined, + x: Optional[str | AnyX | IntoCondition | Map] = Undefined, + x2: Optional[str | AnyX2 | IntoCondition | Map] = Undefined, + xError: Optional[str | AnyXError | IntoCondition | Map] = Undefined, + xError2: Optional[str | AnyXError2 | IntoCondition | Map] = Undefined, + xOffset: Optional[str | AnyXOffset | IntoCondition | Map] = Undefined, + y: Optional[str | AnyY | IntoCondition | Map] = Undefined, + y2: Optional[str | AnyY2 | IntoCondition | Map] = Undefined, + yError: Optional[str | AnyYError | IntoCondition | Map] = Undefined, + yError2: Optional[str | AnyYError2 | IntoCondition | Map] = Undefined, + yOffset: Optional[str | AnyYOffset | IntoCondition | Map] = Undefined, ) -> Self: """ Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`). @@ -25460,43 +25471,43 @@ class EncodeKwds(TypedDict, total=False): Offset of y-position of the marks """ - angle: str | Angle | Map | AngleDatum | AngleValue - color: str | Color | Map | ColorDatum | ColorValue - column: str | Column | Map - description: str | Description | Map | DescriptionValue - detail: OneOrSeq[str | Detail | Map] - facet: str | Facet | Map - fill: str | Fill | Map | FillDatum | FillValue - fillOpacity: str | FillOpacity | Map | FillOpacityDatum | FillOpacityValue - href: str | Href | Map | HrefValue - key: str | Key | Map - latitude: str | Latitude | Map | LatitudeDatum - latitude2: str | Latitude2 | Map | Latitude2Datum | Latitude2Value - longitude: str | Longitude | Map | LongitudeDatum - longitude2: str | Longitude2 | Map | Longitude2Datum | Longitude2Value - opacity: str | Opacity | Map | OpacityDatum | OpacityValue - order: OneOrSeq[str | Order | Map | OrderValue] - radius: str | Radius | Map | RadiusDatum | RadiusValue - radius2: str | Radius2 | Map | Radius2Datum | Radius2Value - row: str | Row | Map - shape: str | Shape | Map | ShapeDatum | ShapeValue - size: str | Size | Map | SizeDatum | SizeValue - stroke: str | Stroke | Map | StrokeDatum | StrokeValue - strokeDash: str | StrokeDash | Map | StrokeDashDatum | StrokeDashValue - strokeOpacity: str | StrokeOpacity | Map | StrokeOpacityDatum | StrokeOpacityValue - strokeWidth: str | StrokeWidth | Map | StrokeWidthDatum | StrokeWidthValue - text: str | Text | Map | TextDatum | TextValue - theta: str | Theta | Map | ThetaDatum | ThetaValue - theta2: str | Theta2 | Map | Theta2Datum | Theta2Value - tooltip: OneOrSeq[str | Tooltip | Map | TooltipValue] - url: str | Url | Map | UrlValue - x: str | X | Map | XDatum | XValue - x2: str | X2 | Map | X2Datum | X2Value - xError: str | XError | Map | XErrorValue - xError2: str | XError2 | Map | XError2Value - xOffset: str | XOffset | Map | XOffsetDatum | XOffsetValue - y: str | Y | Map | YDatum | YValue - y2: str | Y2 | Map | Y2Datum | Y2Value - yError: str | YError | Map | YErrorValue - yError2: str | YError2 | Map | YError2Value - yOffset: str | YOffset | Map | YOffsetDatum | YOffsetValue + angle: str | AnyAngle | IntoCondition | Map + color: str | AnyColor | IntoCondition | Map + column: str | Column | IntoCondition | Map + description: str | AnyDescription | IntoCondition | Map + detail: OneOrSeq[str | Detail | IntoCondition | Map] + facet: str | Facet | IntoCondition | Map + fill: str | AnyFill | IntoCondition | Map + fillOpacity: str | AnyFillOpacity | IntoCondition | Map + href: str | AnyHref | IntoCondition | Map + key: str | Key | IntoCondition | Map + latitude: str | AnyLatitude | IntoCondition | Map + latitude2: str | AnyLatitude2 | IntoCondition | Map + longitude: str | AnyLongitude | IntoCondition | Map + longitude2: str | AnyLongitude2 | IntoCondition | Map + opacity: str | AnyOpacity | IntoCondition | Map + order: OneOrSeq[str | AnyOrder | IntoCondition | Map] + radius: str | AnyRadius | IntoCondition | Map + radius2: str | AnyRadius2 | IntoCondition | Map + row: str | Row | IntoCondition | Map + shape: str | AnyShape | IntoCondition | Map + size: str | AnySize | IntoCondition | Map + stroke: str | AnyStroke | IntoCondition | Map + strokeDash: str | AnyStrokeDash | IntoCondition | Map + strokeOpacity: str | AnyStrokeOpacity | IntoCondition | Map + strokeWidth: str | AnyStrokeWidth | IntoCondition | Map + text: str | AnyText | IntoCondition | Map + theta: str | AnyTheta | IntoCondition | Map + theta2: str | AnyTheta2 | IntoCondition | Map + tooltip: OneOrSeq[str | AnyTooltip | IntoCondition | Map] + url: str | AnyUrl | IntoCondition | Map + x: str | AnyX | IntoCondition | Map + x2: str | AnyX2 | IntoCondition | Map + xError: str | AnyXError | IntoCondition | Map + xError2: str | AnyXError2 | IntoCondition | Map + xOffset: str | AnyXOffset | IntoCondition | Map + y: str | AnyY | IntoCondition | Map + y2: str | AnyY2 | IntoCondition | Map + yError: str | AnyYError | IntoCondition | Map + yError2: str | AnyYError2 | IntoCondition | Map + yOffset: str | AnyYOffset | IntoCondition | Map diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index ee5b73370..a7d2f1c69 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -12,6 +12,7 @@ import tempfile from datetime import date, datetime from importlib.metadata import version as importlib_version +from typing import TYPE_CHECKING import ibis import jsonschema @@ -22,7 +23,7 @@ from packaging.version import Version import altair as alt -from altair.utils.schemapi import Optional, Undefined +from altair.utils.schemapi import Optional, SchemaValidationError, Undefined from tests import skip_requires_vl_convert, slow try: @@ -30,6 +31,10 @@ except ImportError: vlc = None +if TYPE_CHECKING: + from altair.vegalite.v5.api import _Conditional, _Conditions + from altair.vegalite.v5.schema._typing import Map + ibis.set_backend("polars") PANDAS_VERSION = Version(importlib_version("pandas")) @@ -420,7 +425,7 @@ def test_when_then_otherwise() -> None: when_then = alt.when(select).then(alt.value(2, empty=False)) when_then_otherwise = when_then.otherwise(alt.value(0)) - expected = alt.condition(select, alt.value(2, empty=False), alt.value(0)) + expected = dict(alt.condition(select, alt.value(2, empty=False), alt.value(0))) with pytest.raises(TypeError, match="list"): when_then.otherwise([1, 2, 3]) # type: ignore @@ -630,14 +635,24 @@ def test_when_multiple_fields(): with pytest.raises(TypeError, match=chain_mixed_msg): when.then("field_1:Q").when(Genre="pop") + chained_when = when.then(alt.value(5)).when( + alt.selection_point(fields=["b"]) | brush, empty=False, b=63812 + ) + + chain_then_msg = re.compile( + r"Chained.+mixed.+field.+min\(foo\):Q.+'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'", + flags=re.DOTALL, + ) + + with pytest.raises(TypeError, match=chain_then_msg): + chained_when.then("min(foo):Q") + chain_otherwise_msg = re.compile( r"Chained.+mixed.+field.+AggregatedFieldDef.+'this_field_here'", flags=re.DOTALL, ) with pytest.raises(TypeError, match=chain_otherwise_msg): - when.then(alt.value(5)).when( - alt.selection_point(fields=["b"]) | brush, empty=False, b=63812 - ).then("min(foo):Q").otherwise( + chained_when.then(alt.value(2)).otherwise( alt.AggregatedFieldDef( "argmax", field="field_9", **{"as": "this_field_here"} ) @@ -645,21 +660,43 @@ def test_when_multiple_fields(): def test_when_typing(cars) -> None: - color = ( - alt.when(alt.datum.Weight_in_lbs >= 3500) - .then(alt.value("black")) - .otherwise(alt.value("white")) - ) - source = cars - chart = ( # noqa: F841 - alt.Chart(source) - .mark_rect() - .encode( - x=alt.X("Cylinders:N").axis(labelColor=color), - y=alt.Y("Origin:N", axis=alt.Axis(tickColor=color)), - color=color, + chart = alt.Chart(cars).mark_rect() + predicate = alt.datum.Weight_in_lbs >= 3500 + statement = alt.value("black") + default = alt.value("white") + + then: alt.Then[_Conditions] = alt.when(predicate).then(statement) + otherwise: _Conditional[_Conditions] = then.otherwise(default) + condition: Map = alt.condition(predicate, statement, default) + + # NOTE: both `condition()` and `when-then-otherwise` are allowed in these three locations + chart.encode( + color=condition, + x=alt.X("Cylinders:N").axis(labelColor=condition), + y=alt.Y("Origin:N", axis=alt.Axis(tickColor=condition)), + ).to_dict() + + chart.encode( + color=otherwise, + x=alt.X("Cylinders:N").axis(labelColor=otherwise), + y=alt.Y("Origin:N", axis=alt.Axis(tickColor=otherwise)), + ).to_dict() + + with pytest.raises(SchemaValidationError): + # NOTE: `when-then` is allowed as an encoding, but not as a `ConditionalAxisProperty` + # The latter fails validation since it does not have a default `value` + chart.encode( + color=then, + x=alt.X("Cylinders:N").axis(labelColor=then), # type: ignore[call-overload] + y=alt.Y("Origin:N", axis=alt.Axis(labelColor=then)), # type: ignore[arg-type] ) - ) + + # NOTE: Passing validation then requires an `.otherwise()` **only** for the property cases + chart.encode( + color=then, + x=alt.X("Cylinders:N").axis(labelColor=otherwise), + y=alt.Y("Origin:N", axis=alt.Axis(labelColor=otherwise)), + ).to_dict() @pytest.mark.parametrize( @@ -730,7 +767,7 @@ def test_when_then_interactive() -> None: .encode( x="IMDB_Rating:Q", y="Rotten_Tomatoes_Rating:Q", - color=alt.when(predicate).then(alt.value("grey")), # type: ignore[arg-type] + color=alt.when(predicate).then(alt.value("grey")), ) ) assert chart.interactive() diff --git a/tools/generate_api_docs.py b/tools/generate_api_docs.py index d3771d6b7..dc93aba24 100644 --- a/tools/generate_api_docs.py +++ b/tools/generate_api_docs.py @@ -112,18 +112,17 @@ def toplevel_charts() -> list[str]: def encoding_wrappers() -> list[str]: - return sorted(iter_objects(alt.channels, restrict_to_subclass=alt.SchemaBase)) + return sorted(iter_objects(alt.channels, restrict_to_subclass=alt.SchemaBase)) # type: ignore[attr-defined] def api_functions() -> list[str]: # Exclude `typing` functions/SpecialForm(s) - altair_api_functions = [ - obj_name - for obj_name in iter_objects(alt.api, restrict_to_type=types.FunctionType) # type: ignore[attr-defined] - if obj_name - not in {"cast", "overload", "NamedTuple", "TypedDict", "is_chart_type"} - ] - return sorted(altair_api_functions) + KEEP = set(alt.api.__all__) - set(alt.typing.__all__) # type: ignore[attr-defined] + return sorted( + name + for name in iter_objects(alt.api, restrict_to_type=types.FunctionType) # type: ignore[attr-defined] + if name in KEEP + ) def api_classes() -> list[str]: @@ -132,7 +131,7 @@ def api_classes() -> list[str]: def type_hints() -> list[str]: - return [s for s in sorted(iter_objects(alt.typing)) if s != "annotations"] + return sorted(s for s in iter_objects(alt.typing) if s in alt.typing.__all__) def lowlevel_wrappers() -> list[str]: diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 7bac492d7..b3598c1c9 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -266,6 +266,8 @@ def func( """ ''' +INTO_CONDITION: Literal["IntoCondition"] = "IntoCondition" + class SchemaGenerator(codegen.SchemaGenerator): schema_class_template = textwrap.dedent( @@ -693,6 +695,7 @@ def generate_vegalite_channel_wrappers( "from altair import Parameter, SchemaBase", "from altair.typing import Optional", "from typing_extensions import Self", + f"from altair.vegalite.v5.api import {INTO_CONDITION}", ), "\n" f"__all__ = {sorted(all_)}\n", CHANNEL_MIXINS, @@ -900,20 +903,33 @@ def generate_encoding_artifacts( - but this translates poorly to an IDE - `info.supports_arrays` """ + PREFIX_INTERNAL = "Any" + PREFIX_EXPORT = "Channel" signature_args: list[str] = ["self", "*args: Any"] - type_aliases: list[str] = [] + internal_aliases: list[str] = [] + export_aliases: list[str] = [] typed_dict_args: list[str] = [] signature_doc_params: list[str] = ["", "Parameters", "----------"] typed_dict_doc_params: list[str] = ["", "Parameters", "----------"] for channel, info in channel_infos.items(): - alias_name: str = f"Channel{channel[0].upper()}{channel[1:]}" + channel_name = f"{channel[0].upper()}{channel[1:]}" - it: Iterator[str] = info.all_names + names = list(info.all_names) it_rst_names: Iterator[str] = (rst_syntax_for_class(c) for c in info.all_names) docstring_types: list[str] = ["str", next(it_rst_names), "Dict"] - tp_inner: str = ", ".join(chain(("str", next(it), "Map"), it)) + if len(names) > 1: + # NOTE: Another level of internal aliases are generated, for channels w/ 2-3 types. + # These are represent only types defined in `channels.py` and not the full range accepted. + any_name = f"{PREFIX_INTERNAL}{channel_name}" + internal_aliases.append( + f"{any_name}: TypeAlias = Union[{', '.join(names)}]" + ) + tp_inner: str = ", ".join(("str", any_name, f"{INTO_CONDITION!r}", "Map")) + else: + tp_inner = ", ".join(("str", names[0], f"{INTO_CONDITION!r}", "Map")) + tp_inner = f"Union[{tp_inner}]" if info.supports_arrays: @@ -922,7 +938,7 @@ def generate_encoding_artifacts( doc_types_flat: str = ", ".join(chain(docstring_types, it_rst_names)) - type_aliases.append(f"{alias_name}: TypeAlias = {tp_inner}") + export_aliases.append(f"{PREFIX_EXPORT}{channel_name}: TypeAlias = {tp_inner}") # We use the full type hints instead of the alias in the signatures below # as IDEs such as VS Code would else show the name of the alias instead # of the expanded full type hints. The later are more useful to users. @@ -942,7 +958,13 @@ def generate_encoding_artifacts( channels="\n ".join(typed_dict_args), docstring=indent_docstring(typed_dict_doc_params, indent_level=4, lstrip=False), ) - artifacts: Iterable[str] = *type_aliases, method, typed_dict + artifacts: Iterable[str] = ( + *internal_aliases, + "", + *export_aliases, + method, + typed_dict, + ) yield from artifacts diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 5140073ad..3089886eb 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -16,10 +16,12 @@ Any, Dict, Final, + Generic, Iterable, Iterator, List, Literal, + Mapping, Sequence, TypeVar, Union, @@ -39,6 +41,11 @@ # not yet be fully instantiated in case your code is being executed during import time from altair import vegalite +if sys.version_info >= (3, 12): + from typing import Protocol, TypeAliasType, runtime_checkable +else: + from typing_extensions import Protocol, TypeAliasType, runtime_checkable + if TYPE_CHECKING: from types import ModuleType from typing import ClassVar @@ -522,11 +529,7 @@ def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) for k, v in obj.items() if v is not Undefined } - elif ( - hasattr(obj, "to_dict") - and (module_name := obj.__module__) - and module_name.startswith("altair") - ): + elif isinstance(obj, SchemaLike): return obj.to_dict() elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): return pd_opt.Timestamp(obj).isoformat() @@ -787,6 +790,95 @@ def _get_default_error_message( return message +_JSON_VT_co = TypeVar( + "_JSON_VT_co", + Literal["string"], + Literal["object"], + Literal["array"], + covariant=True, +) +""" +One of a subset of JSON Schema `primitive types`_: + + ["string", "object", "array"] + +.. _primitive types: + https://json-schema.org/draft-07/json-schema-validation#rfc.section.6.1.1 +""" + +_TypeMap = TypeAliasType( + "_TypeMap", Mapping[Literal["type"], _JSON_VT_co], type_params=(_JSON_VT_co,) +) +""" +A single item JSON Schema using the `type`_ keyword. + +This may represent **one of**: + + {"type": "string"} + {"type": "object"} + {"type": "array"} + +.. _type: + https://json-schema.org/understanding-json-schema/reference/type +""" + +# NOTE: Type checkers want opposing things: +# - `mypy` : Covariant type variable "_JSON_VT_co" used in protocol where invariant one is expected [misc] +# - `pyright`: Type variable "_JSON_VT_co" used in generic protocol "SchemaLike" should be covariant [reportInvalidTypeVarUse] +# Siding with `pyright` as this is consistent with https://github.com/python/typeshed/blob/9e506eb5e8fc2823db8c60ad561b1145ff114947/stdlib/typing.pyi#L690 + + +@runtime_checkable +class SchemaLike(Generic[_JSON_VT_co], Protocol): # type: ignore[misc] + """ + Represents ``altair`` classes which *may* not derive ``SchemaBase``. + + Attributes + ---------- + _schema + A single item JSON Schema using the `type`_ keyword. + + Notes + ----- + Should be kept tightly defined to the **minimum** requirements for: + - Converting into a form that can be validated by `jsonschema`_. + - Avoiding calling ``.to_dict()`` on a class external to ``altair``. + - ``_schema`` is more accurately described as a ``ClassVar`` + - See `discussion`_ for blocking issue. + + .. _jsonschema: + https://github.com/python-jsonschema/jsonschema + .. _type: + https://json-schema.org/understanding-json-schema/reference/type + .. _discussion: + https://github.com/python/typing/discussions/1424 + """ + + _schema: _TypeMap[_JSON_VT_co] + + def to_dict(self, *args, **kwds) -> Any: ... + + +@runtime_checkable +class ConditionLike(SchemaLike[Literal["object"]], Protocol): + """ + Represents the wrapped state of a conditional encoding or property. + + Attributes + ---------- + condition + One or more (predicate, statement) pairs which each form a condition. + + Notes + ----- + - Can be extended with additional conditions. + - *Does not* define a default value, but can be finalized with one. + """ + + condition: Any + _schema: _TypeMap[Literal["object"]] = {"type": "object"} + + class UndefinedType: """A singleton object for marking undefined parameters."""