From 127a4e9dd315329bf86c265ef268b0ee3e4324f0 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Wed, 9 Aug 2023 19:20:31 +0200 Subject: [PATCH] First batch of type hints --- altair/vegalite/v5/api.py | 113 +++++++++++++++++------------- altair/vegalite/v5/schema/core.py | 2 +- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 70680984a6..75b2aa54d6 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,13 +8,13 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any, Iterable +from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal # Have to rename it here as else it overlaps with schema.core.Type from typing import Type as TypingType from typing import Dict as TypingDict -from .schema import core, channels, mixins, Undefined, SCHEMA_URL +from .schema import core, channels, mixins, Undefined, SCHEMA_URL, UndefinedType from .data import data_transformers from ... import utils, expr @@ -35,13 +35,13 @@ # ------------------------------------------------------------------------ # Data Utilities -def _dataset_name(values): +def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str: """Generate a unique hash of the data Parameters ---------- - values : list or dict - A list/dict representation of data values. + values : list, dict, core.InlineDataset + A representation of data values. Returns ------- @@ -136,7 +136,7 @@ class LookupData(core.LookupData): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: """Convert the chart to a dictionary suitable for JSON export.""" copy = self.copy(deep=False) copy.data = _prepare_data(copy.data, kwargs.get("context")) @@ -150,7 +150,7 @@ class FacetMapping(core.FacetMapping): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: copy = self.copy(deep=False) context = kwargs.get("context", {}) data = context.get("data", None) @@ -172,8 +172,8 @@ def to_dict(self, *args, **kwargs): TOPLEVEL_ONLY_KEYS = {"background", "config", "autosize", "padding", "$schema"} -def _get_channels_mapping(): - mapping = {} +def _get_channels_mapping() -> Dict[Type[core.SchemaBase], str]: + mapping: Dict[Type[core.SchemaBase], str] = {} for attr in dir(channels): cls = getattr(channels, attr) if isinstance(cls, type) and issubclass(cls, core.SchemaBase): @@ -189,11 +189,11 @@ class Parameter(expr.core.OperatorMixin, object): _counter = 0 @classmethod - def _get_name(cls): + def _get_name(cls) -> str: cls._counter += 1 return f"param_{cls._counter}" - def __init__(self, name): + def __init__(self, name: Optional[str]) -> None: if name is None: name = self._get_name() self.name = name @@ -201,11 +201,11 @@ def __init__(self, name): @utils.deprecation.deprecated( message="'ref' is deprecated. No need to call '.ref()' anymore." ) - def ref(self): + def ref(self) -> dict: "'ref' is deprecated. No need to call '.ref()' anymore." return self.to_dict() - def to_dict(self): + def to_dict(self) -> dict: if self.param_type == "variable": return {"expr": self.name} elif self.param_type == "selection": @@ -214,6 +214,8 @@ def to_dict(self): if hasattr(self.name, "to_dict") else self.name } + else: + raise ValueError(f"Unrecognized parameter type: {self.param_type}") def __invert__(self): if self.param_type == "selection": @@ -237,16 +239,18 @@ def __or__(self, other): else: return expr.core.OperatorMixin.__or__(self, other) - def __repr__(self): + def __repr__(self) -> str: return "Parameter({0!r}, {1})".format(self.name, self.param) - def _to_expr(self): + def _to_expr(self) -> str: return self.name - def _from_expr(self, expr): + def _from_expr(self, expr) -> "ParameterExpression": return ParameterExpression(expr=expr) - def __getattr__(self, field_name): + def __getattr__( + self, field_name: str + ) -> Union["SelectionExpression", expr.core.GetAttrExpression]: if field_name.startswith("__") and field_name.endswith("__"): raise AttributeError(field_name) _attrexpr = expr.core.GetAttrExpression(self.name, field_name) @@ -258,7 +262,7 @@ def __getattr__(self, field_name): # TODO: Are there any special cases to consider for __getitem__? # This was copied from v4. - def __getitem__(self, field_name): + def __getitem__(self, field_name: str) -> expr.core.GetItemExpression: return expr.core.GetItemExpression(self.name, field_name) @@ -275,34 +279,34 @@ def __or__(self, other): class ParameterExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> Dict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "ParameterExpression": return ParameterExpression(expr=expr) class SelectionExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> Dict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "SelectionExpression": return SelectionExpression(expr=expr) -def check_fields_and_encodings(parameter, field_name): +def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: for prop in ["fields", "encodings"]: try: if field_name in getattr(parameter.param.select, prop): @@ -317,20 +321,24 @@ def check_fields_and_encodings(parameter, field_name): # Top-Level Functions -def value(value, **kwargs): +def value(value, **kwargs) -> dict: """Specify a value for use in an encoding""" return dict(value=value, **kwargs) def param( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, + name: Optional[str] = None, + value: Union[UndefinedType, Any] = Undefined, + bind: Union[UndefinedType, core.Binding] = Undefined, + empty: Union[UndefinedType, bool] = Undefined, + expr: Union[UndefinedType, core.Expr] = Undefined, **kwds, -): - """Create a named parameter. See https://altair-viz.github.io/user_guide/interactions.html for examples. Although both variable parameters and selection parameters can be created using this 'param' function, to create a selection parameter, it is recommended to use either 'selection_point' or 'selection_interval' instead. +) -> Parameter: + """Create a named parameter. + See https://altair-viz.github.io/user_guide/interactions.html for examples. + Although both variable parameters and selection parameters can be created using + this 'param' function, to create a selection parameter, it is recommended to use + either 'selection_point' or 'selection_interval' instead. Parameters ---------- @@ -415,7 +423,9 @@ def param( return parameter -def _selection(type=Undefined, **kwds): +def _selection( + type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds +) -> Parameter: # We separate out the parameter keywords from the selection keywords param_kwds = {} @@ -423,6 +433,7 @@ def _selection(type=Undefined, **kwds): if kwd in kwds: param_kwds[kwd] = kwds.pop(kwd) + select: Union[core.IntervalSelectionConfig, core.PointSelectionConfig] if type == "interval": select = core.IntervalSelectionConfig(type=type, **kwds) elif type == "point": @@ -445,7 +456,9 @@ def _selection(type=Undefined, **kwds): message="""'selection' is deprecated. Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings.""" ) -def selection(type=Undefined, **kwds): +def selection( + type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds +) -> Parameter: """ Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create. @@ -469,20 +482,20 @@ def selection(type=Undefined, **kwds): def selection_interval( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, - encodings=Undefined, - on=Undefined, - clear=Undefined, - resolve=Undefined, - mark=Undefined, - translate=Undefined, - zoom=Undefined, + name: Optional[str] = None, + value: Union[UndefinedType, Any] = Undefined, + bind: Union[UndefinedType, core.Binding] = Undefined, + empty: Union[UndefinedType, bool] = Undefined, + expr: Union[UndefinedType, core.Expr] = Undefined, + encodings: Union[UndefinedType, List[str]] = Undefined, + on: Union[UndefinedType, str] = Undefined, + clear: Union[UndefinedType, str, bool] = Undefined, + resolve: Union[UndefinedType, Literal["global", "union", "intersect"]] = Undefined, + mark: Union[UndefinedType, core.Mark] = Undefined, + translate: Union[UndefinedType, str, bool] = Undefined, + zoom: Union[UndefinedType, str, bool] = Undefined, **kwds, -): +) -> Parameter: """Create an interval selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Interval selection parameters are used to select a continuous range of data values on drag, whereas point selection parameters (`selection_point`) are used to select multiple discrete data values.) Parameters diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index ad1e5afd50..d2f4887e6b 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -1,7 +1,7 @@ # The contents of this file are automatically written by # tools/generate_schema_wrapper.py. Do not modify directly. -from altair.utils.schemapi import SchemaBase, Undefined, _subclasses +from altair.utils.schemapi import SchemaBase, Undefined, _subclasses, UndefinedType import pkgutil import json