Skip to content

Commit

Permalink
Type hints: Infer types from vegalite schema for autogenerated code (#…
Browse files Browse the repository at this point in the history
…3208)

* Include schema files in mypy type checking

* Use actual names of Python types in docstrings

* Simplify medium_description by removing unused code paths

* Use Python type hint syntax instead of pseudo jsonschema syntax in docstrings

* First version of autoinferred type hints

* Various improvements

* Add type hints to mixins.py

* Bug fix: Add missing overload operators

* Add type hints to arguments of overload signatures where the signature has multiple arguments with Undefined default value

* Add type hint for Parameter

* Deduplicate type hints

* Minor renaming and reordering of type hints

* Recursively go through types to get all for type hints as well as docstrings

* Instead of writing a List of Unions, write a Union of Lists, i.e. do not allow combinations of different types within a single list. This works around the invariance of lists

* Use Sequence instead of List

* Simplify code

* Remove unused ignore statement. Seems like mypy can now correctly detect that the overload statements are correct

* Switch to tooltip specification which conforms to type hints

* Ignore mypy errors when it cannot detect the existence of the .copy method

* Explicitly pass kwarg arguments in these examples

* Fix schema for 'shorthand' as it also accepts an array of strings

* Revert "Switch to tooltip specification which conforms to type hints"

This reverts commit 54439c8.

* Fix some more mypy errors

* Ignore mypy overload errors

* Black and ruff

* Add type hints to channel mixins

* Simplify docstrings by removing Union so they look better in docs. Remove type hints from signatures in docs as they are already shown in docstrings

* Fix merge conflicts and rerun schema generation

* Remove redundant capability with VL v2

* Factor out adding shorthand into a separate function. Fix missing Field class definitions in docstring where previously only shorthand was shown

* Add missing shorthand to field definitions in core.py

* Add shorthand str and list[str] type hints to signature of encode method

* Fix various mypy errors

* Fix mypy error for data argument. Add RepeatRef as a type hint for shorthand

* Add _ParameterProtocol wherever ParameterExtent is accepted

* Switch to ruff as code formatter

* Ruff fix

* Move type ignore comment which was shifted by Ruff

* Try to fix mypy issue in pipeline which does not appear locally

* Change type signature of encode method to only include the types which users expect

* Only show type hints for datum and value if they are accepted

* Fix trailing whitespace

* Switch back to just 'dict'

* Add type hint for list to encoding channels which support it

* Rename _ParameterProtocol to _Parameter

* Fix ruff error

* Add parameters, including type hints, to docstrings (without descriptions for now)

* Add descriptions

* Change import statement for expr.core to help type checkers
  • Loading branch information
binste authored Nov 22, 2023
1 parent bf52ed1 commit 4a1564b
Show file tree
Hide file tree
Showing 16 changed files with 134,472 additions and 13,028 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ hatch run test
```


This also runs the [`black`](https://black.readthedocs.io/) code formatter, [`ruff`](https://ruff.rs/) linter and [`mypy`](https://mypy-lang.org/) as type checker.
This also runs the [`ruff`](https://ruff.rs/) linter and formatter as well as [`mypy`](https://mypy-lang.org/) as type checker.


Study the output of any failed tests and try to fix the issues
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
else:
from typing_extensions import Self

_TSchemaBase = TypeVar("_TSchemaBase", bound="SchemaBase")
_TSchemaBase = TypeVar("_TSchemaBase", bound=Type["SchemaBase"])

ValidationErrorList = List[jsonschema.exceptions.ValidationError]
GroupedValidationErrors = Dict[str, ValidationErrorList]
Expand Down
81 changes: 51 additions & 30 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .data import data_transformers
from ... import utils, expr
from ...expr import core as _expr_core
from .display import renderers, VEGALITE_VERSION, VEGAEMBED_VERSION, VEGA_VERSION
from .theme import themes
from .compiler import vegalite_compilers
Expand Down Expand Up @@ -186,9 +187,12 @@ def _get_channels_mapping() -> TypingDict[TypingType[core.SchemaBase], str]:

# -------------------------------------------------------------------------
# Tools for working with parameters
class Parameter(expr.core.OperatorMixin, object):
class Parameter(_expr_core.OperatorMixin):
"""A Parameter object"""

# NOTE: If you change this class, make sure that the protocol in
# altair/vegalite/v5/schema/core.py is updated accordingly if needed.

_counter: int = 0

@classmethod
Expand Down Expand Up @@ -238,23 +242,23 @@ def __invert__(self):
if self.param_type == "selection":
return SelectionPredicateComposition({"not": {"param": self.name}})
else:
return expr.core.OperatorMixin.__invert__(self)
return _expr_core.OperatorMixin.__invert__(self)

def __and__(self, other):
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"and": [{"param": self.name}, other]})
else:
return expr.core.OperatorMixin.__and__(self, other)
return _expr_core.OperatorMixin.__and__(self, other)

def __or__(self, other):
if self.param_type == "selection":
if isinstance(other, Parameter):
other = {"param": other.name}
return SelectionPredicateComposition({"or": [{"param": self.name}, other]})
else:
return expr.core.OperatorMixin.__or__(self, other)
return _expr_core.OperatorMixin.__or__(self, other)

def __repr__(self) -> str:
return "Parameter({0!r}, {1})".format(self.name, self.param)
Expand All @@ -267,20 +271,20 @@ def _from_expr(self, expr) -> "ParameterExpression":

def __getattr__(
self, field_name: str
) -> Union[expr.core.GetAttrExpression, "SelectionExpression"]:
) -> Union[_expr_core.GetAttrExpression, "SelectionExpression"]:
if field_name.startswith("__") and field_name.endswith("__"):
raise AttributeError(field_name)
_attrexpr = expr.core.GetAttrExpression(self.name, field_name)
_attrexpr = _expr_core.GetAttrExpression(self.name, field_name)
# If self is a SelectionParameter and field_name is in its
# fields or encodings list, then we want to return an expression.
if check_fields_and_encodings(self, field_name):
return SelectionExpression(_attrexpr)
return expr.core.GetAttrExpression(self.name, field_name)
return _expr_core.GetAttrExpression(self.name, field_name)

# TODO: Are there any special cases to consider for __getitem__?
# This was copied from v4.
def __getitem__(self, field_name: str) -> expr.core.GetItemExpression:
return expr.core.GetItemExpression(self.name, field_name)
def __getitem__(self, field_name: str) -> _expr_core.GetItemExpression:
return _expr_core.GetItemExpression(self.name, field_name)


# Enables use of ~, &, | with compositions of selection objects.
Expand All @@ -295,7 +299,7 @@ def __or__(self, other):
return SelectionPredicateComposition({"or": [self.to_dict(), other.to_dict()]})


class ParameterExpression(expr.core.OperatorMixin, object):
class ParameterExpression(_expr_core.OperatorMixin, object):
def __init__(self, expr) -> None:
self.expr = expr

Expand All @@ -309,7 +313,7 @@ def _from_expr(self, expr) -> "ParameterExpression":
return ParameterExpression(expr=expr)


class SelectionExpression(expr.core.OperatorMixin, object):
class SelectionExpression(_expr_core.OperatorMixin, object):
def __init__(self, expr) -> None:
self.expr = expr

Expand Down Expand Up @@ -346,9 +350,9 @@ def value(value, **kwargs) -> dict:
def param(
name: Optional[str] = None,
value: Union[Any, UndefinedType] = Undefined,
bind: Union[core.Binding, str, UndefinedType] = Undefined,
bind: Union[core.Binding, UndefinedType] = Undefined,
empty: Union[bool, UndefinedType] = Undefined,
expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined,
expr: Union[str, core.Expr, _expr_core.Expression, UndefinedType] = Undefined,
**kwds,
) -> Parameter:
"""Create a named parameter.
Expand All @@ -365,7 +369,7 @@ def param(
value : any (optional)
The default value of the parameter. If not specified, the parameter
will be created without a default value.
bind : :class:`Binding`, str (optional)
bind : :class:`Binding` (optional)
Binds the parameter to an external input element such as a slider,
selection list or radio button group.
empty : boolean (optional)
Expand Down Expand Up @@ -421,9 +425,14 @@ def param(
# If both 'value' and 'init' are set, we ignore 'init'.
kwds.pop("init")

# ignore[arg-type] comment is needed because we can also pass _expr_core.Expression
if "select" not in kwds:
parameter.param = core.VariableParameter(
name=parameter.name, bind=bind, value=value, expr=expr, **kwds
name=parameter.name,
bind=bind,
value=value,
expr=expr, # type: ignore[arg-type]
**kwds,
)
parameter.param_type = "variable"
elif "views" in kwds:
Expand Down Expand Up @@ -503,7 +512,7 @@ def selection_interval(
value: Union[Any, UndefinedType] = Undefined,
bind: Union[core.Binding, str, UndefinedType] = Undefined,
empty: Union[bool, UndefinedType] = Undefined,
expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined,
expr: Union[str, core.Expr, _expr_core.Expression, UndefinedType] = Undefined,
encodings: Union[List[str], UndefinedType] = Undefined,
on: Union[str, UndefinedType] = Undefined,
clear: Union[str, bool, UndefinedType] = Undefined,
Expand Down Expand Up @@ -807,7 +816,7 @@ def condition(
test_predicates = (str, expr.Expression, core.PredicateComposition)

condition: TypingDict[
str, Union[bool, str, expr.core.Expression, core.PredicateComposition]
str, Union[bool, str, _expr_core.Expression, core.PredicateComposition]
]
if isinstance(predicate, Parameter):
if (
Expand Down Expand Up @@ -1466,7 +1475,9 @@ def project(
spacing=spacing,
tilt=tilt,
translate=translate,
type=type,
# Ignore as we type here `type` as a str but in core.Projection
# it's a Literal with all options
type=type, # type: ignore[arg-type]
**kwds,
)
return self.properties(projection=projection)
Expand Down Expand Up @@ -1627,9 +1638,9 @@ def transform_calculate(
self,
as_: Union[str, core.FieldName, UndefinedType] = Undefined,
calculate: Union[
str, core.Expr, expr.core.Expression, UndefinedType
str, core.Expr, _expr_core.Expression, UndefinedType
] = Undefined,
**kwargs: Union[str, core.Expr, expr.core.Expression],
**kwargs: Union[str, core.Expr, _expr_core.Expression],
) -> Self:
"""
Add a :class:`CalculateTransform` to the schema.
Expand Down Expand Up @@ -1690,10 +1701,10 @@ def transform_calculate(
)
if as_ is not Undefined or calculate is not Undefined:
dct = {"as": as_, "calculate": calculate}
self = self._add_transform(core.CalculateTransform(**dct))
self = self._add_transform(core.CalculateTransform(**dct)) # type: ignore[arg-type]
for as_, calculate in kwargs.items():
dct = {"as": as_, "calculate": calculate}
self = self._add_transform(core.CalculateTransform(**dct))
self = self._add_transform(core.CalculateTransform(**dct)) # type: ignore[arg-type]
return self

def transform_density(
Expand Down Expand Up @@ -1922,7 +1933,7 @@ def transform_filter(
filter: Union[
str,
core.Expr,
expr.core.Expression,
_expr_core.Expression,
core.Predicate,
Parameter,
core.PredicateComposition,
Expand Down Expand Up @@ -1956,7 +1967,7 @@ def transform_filter(
elif isinstance(filter.empty, bool):
new_filter["empty"] = filter.empty
filter = new_filter # type: ignore[assignment]
return self._add_transform(core.FilterTransform(filter=filter, **kwargs))
return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) # type: ignore[arg-type]

def transform_flatten(
self,
Expand Down Expand Up @@ -2158,7 +2169,13 @@ def transform_pivot(
"""
return self._add_transform(
core.PivotTransform(
pivot=pivot, value=value, groupby=groupby, limit=limit, op=op
# Ignore as we type here `op` as a str but in core.PivotTransform
# it's a Literal with all options
pivot=pivot,
value=value,
groupby=groupby,
limit=limit,
op=op, # type: ignore[arg-type]
)
)

Expand Down Expand Up @@ -2408,7 +2425,7 @@ def transform_timeunit(
)
if as_ is not Undefined:
dct = {"as": as_, "timeUnit": timeUnit, "field": field}
self = self._add_transform(core.TimeUnitTransform(**dct))
self = self._add_transform(core.TimeUnitTransform(**dct)) # type: ignore[arg-type]
for as_, shorthand in kwargs.items():
dct = utils.parse_shorthand(
shorthand,
Expand All @@ -2420,7 +2437,7 @@ def transform_timeunit(
dct["as"] = as_
if "timeUnit" not in dct:
raise ValueError("'{}' must include a valid timeUnit".format(shorthand))
self = self._add_transform(core.TimeUnitTransform(**dct))
self = self._add_transform(core.TimeUnitTransform(**dct)) # type: ignore[arg-type]
return self

def transform_window(
Expand Down Expand Up @@ -2516,7 +2533,8 @@ def transform_window(
)
)
assert not isinstance(window, UndefinedType) # For mypy
window.append(core.WindowFieldDef(**kwds))
# Ignore as core.WindowFieldDef has a Literal type hint with all options
window.append(core.WindowFieldDef(**kwds)) # type: ignore[arg-type]

return self._add_transform(
core.WindowTransform(
Expand Down Expand Up @@ -2697,7 +2715,7 @@ def resolve_scale(self, *args, **kwargs) -> Self:


class _EncodingMixin:
@utils.use_signature(core.FacetedEncoding)
@utils.use_signature(channels._encode_signature)
def encode(self, *args, **kwargs) -> Self:
# Convert args to kwargs based on their types.
kwargs = utils.infer_encoding_types(args, kwargs, channels)
Expand Down Expand Up @@ -2853,7 +2871,10 @@ def __init__(
**kwargs,
) -> None:
super(Chart, self).__init__(
data=data,
# Data type hints won't match with what TopLevelUnitSpec expects
# as there is some data processing happening when converting to
# a VL spec
data=data, # type: ignore[arg-type]
encoding=encoding,
mark=mark,
width=width,
Expand Down
9 changes: 6 additions & 3 deletions altair/vegalite/v5/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ruff: noqa

from .core import *
from .channels import *
SCHEMA_VERSION = 'v5.15.1'
SCHEMA_URL = 'https://vega.github.io/schema/vega-lite/v5.15.1.json'
from .channels import * # type: ignore[assignment]

SCHEMA_VERSION = "v5.15.1"

SCHEMA_URL = "https://vega.github.io/schema/vega-lite/v5.15.1.json"
Loading

0 comments on commit 4a1564b

Please sign in to comment.