From b37150222db5d70c31d865b18e5257a5c646c97e Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 19:40:30 +0200 Subject: [PATCH 1/4] First type hints for code generation --- pyproject.toml | 1 + tools/generate_api_docs.py | 29 ++++---- tools/generate_schema_wrapper.py | 68 ++++++++++-------- tools/schemapi/utils.py | 114 ++++++++++++++++++------------- tools/update_init_file.py | 8 +-- 5 files changed, 128 insertions(+), 92 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4def443b6..f60d2d9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,6 +230,7 @@ module = [ "pandas.lib.*", "nbformat.*", "ipykernel.*", + "m2r.*", ] ignore_missing_imports = true diff --git a/tools/generate_api_docs.py b/tools/generate_api_docs.py index 20335c116..46b997313 100644 --- a/tools/generate_api_docs.py +++ b/tools/generate_api_docs.py @@ -5,15 +5,17 @@ import sys import types from os.path import abspath, dirname, join +from typing import Final, Optional, Iterator, List +from types import ModuleType # Import Altair from head ROOT_DIR = abspath(join(dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) import altair as alt # noqa: E402 -API_FILENAME = join(ROOT_DIR, "doc", "user_guide", "api.rst") +API_FILENAME: Final = join(ROOT_DIR, "doc", "user_guide", "api.rst") -API_TEMPLATE = """\ +API_TEMPLATE: Final = """\ .. _api: API Reference @@ -68,8 +70,11 @@ def iter_objects( - mod, ignore_private=True, restrict_to_type=None, restrict_to_subclass=None -): + mod: ModuleType, + ignore_private: bool = True, + restrict_to_type: Optional[type] = None, + restrict_to_subclass: Optional[type] = None, +) -> Iterator[str]: for name in dir(mod): obj = getattr(mod, name) if ignore_private: @@ -84,26 +89,26 @@ def iter_objects( yield name -def toplevel_charts(): - return sorted(iter_objects(alt.api, restrict_to_subclass=alt.TopLevelMixin)) +def toplevel_charts() -> List[str]: + return sorted(iter_objects(alt.api, restrict_to_subclass=alt.TopLevelMixin)) # type: ignore[attr-defined] -def encoding_wrappers(): +def encoding_wrappers() -> List[str]: return sorted(iter_objects(alt.channels, restrict_to_subclass=alt.SchemaBase)) -def api_functions(): +def api_functions() -> List[str]: # Exclude typing.cast altair_api_functions = [ obj_name - for obj_name in iter_objects(alt.api, restrict_to_type=types.FunctionType) + for obj_name in iter_objects(alt.api, restrict_to_type=types.FunctionType) # type: ignore[attr-defined] if obj_name != "cast" ] return sorted(altair_api_functions) -def lowlevel_wrappers(): - objects = sorted(iter_objects(alt.schema.core, restrict_to_subclass=alt.SchemaBase)) +def lowlevel_wrappers() -> List[str]: + objects = sorted(iter_objects(alt.schema.core, restrict_to_subclass=alt.SchemaBase)) # type: ignore[attr-defined] # The names of these two classes are also used for classes in alt.channels. Due to # how imports are set up, these channel classes overwrite the two low-level classes # in the top-level Altair namespace. Therefore, they cannot be imported as e.g. @@ -113,7 +118,7 @@ def lowlevel_wrappers(): return objects -def write_api_file(): +def write_api_file() -> None: print("Updating API docs\n ->{}".format(API_FILENAME)) sep = "\n " with open(API_FILENAME, "w") as f: diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 600ebdfbc..6325ebacd 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -6,6 +6,7 @@ import json import re from os.path import abspath, join, dirname +from typing import Final, Optional, List, Dict, Tuple, Literal, Union, Type import textwrap from urllib import request @@ -29,21 +30,21 @@ ) # Map of version name to github branch name. -SCHEMA_VERSION = { +SCHEMA_VERSION: Final = { "vega-lite": {"v5": "v5.14.1"}, } reLink = re.compile(r"(?<=\[)([^\]]+)(?=\]\([^\)]+\))", re.M) reSpecial = re.compile(r"[*_]{2,3}|`", re.M) -HEADER = """\ +HEADER: Final = """\ # The contents of this file are automatically written by # tools/generate_schema_wrapper.py. Do not modify directly. """ -SCHEMA_URL_TEMPLATE = "https://vega.github.io/schema/" "{library}/{version}.json" +SCHEMA_URL_TEMPLATE: Final = "https://vega.github.io/schema/" "{library}/{version}.json" -BASE_SCHEMA = """ +BASE_SCHEMA: Final = """ class {basename}(SchemaBase): _rootschema = load_schema() @classmethod @@ -51,7 +52,7 @@ def _default_wrapper_classes(cls): return _subclasses({basename}) """ -LOAD_SCHEMA = ''' +LOAD_SCHEMA: Final = ''' import pkgutil import json @@ -61,7 +62,7 @@ def load_schema(): ''' -CHANNEL_MIXINS = """ +CHANNEL_MIXINS: Final = """ class FieldChannelMixin: def to_dict(self, validate=True, ignore=(), context=None): context = context or {} @@ -146,7 +147,7 @@ def to_dict(self, validate=True, ignore=(), context=None): context=context) """ -MARK_METHOD = ''' +MARK_METHOD: Final = ''' def mark_{mark}({def_arglist}) -> Self: """Set the chart's mark to '{mark}' (see :class:`{mark_def}`) """ @@ -159,7 +160,7 @@ def mark_{mark}({def_arglist}) -> Self: return copy ''' -CONFIG_METHOD = """ +CONFIG_METHOD: Final = """ @use_signature(core.{classname}) def {method}(self, *args, **kwargs) -> Self: copy = self.copy(deep=False) @@ -167,7 +168,7 @@ def {method}(self, *args, **kwargs) -> Self: return copy """ -CONFIG_PROP_METHOD = """ +CONFIG_PROP_METHOD: Final = """ @use_signature(core.{classname}) def configure_{prop}(self, *args, **kwargs) -> Self: copy = self.copy(deep=['config']) @@ -189,7 +190,7 @@ class {classname}({basename}): ''' ) - def _process_description(self, description): + def _process_description(self, description: str): description = "".join( [ reSpecial.sub("", d) if i % 2 else d @@ -257,16 +258,18 @@ class {classname}(DatumChannelMixin, core.{basename}): ) -def schema_class(*args, **kwargs): +def schema_class(*args, **kwargs) -> str: return SchemaGenerator(*args, **kwargs).schema_class() -def schema_url(library, version): +def schema_url(library: str, version: str) -> str: version = SCHEMA_VERSION[library][version] return SCHEMA_URL_TEMPLATE.format(library=library, version=version) -def download_schemafile(library, version, schemapath, skip_download=False): +def download_schemafile( + library: str, version: str, schemapath: str, skip_download: bool = False +) -> str: url = schema_url(library, version) if not os.path.exists(schemapath): os.makedirs(schemapath) @@ -278,7 +281,7 @@ def download_schemafile(library, version, schemapath, skip_download=False): return filename -def copy_schemapi_util(): +def copy_schemapi_util() -> None: """ Copy the schemapi utility into altair/utils/ and its test file to tests/utils/ """ @@ -295,7 +298,7 @@ def copy_schemapi_util(): dest.writelines(source.readlines()) -def recursive_dict_update(schema, root, def_dict): +def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None: if "$ref" in schema: next_schema = resolve_references(schema, root) if "properties" in next_schema: @@ -311,7 +314,7 @@ def recursive_dict_update(schema, root, def_dict): recursive_dict_update(sub_schema, root, def_dict) -def get_field_datum_value_defs(propschema, root): +def get_field_datum_value_defs(propschema: SchemaInfo, root: dict) -> dict: def_dict = {k: None for k in ("field", "datum", "value")} schema = propschema.schema if propschema.is_reference() and "properties" in schema: @@ -325,7 +328,7 @@ def get_field_datum_value_defs(propschema, root): return {i: j for i, j in def_dict.items() if j} -def toposort(graph): +def toposort(graph: Dict[str, List[str]]) -> List[str]: """Topological sort of a directed acyclic graph. Parameters @@ -339,8 +342,10 @@ def toposort(graph): order : list topological order of input graph. """ - stack = [] - visited = {} + # Once we drop support for Python 3.8, this can potentially be replaced + # with graphlib.TopologicalSorter from the standard library. + stack: List[str] = [] + visited: Dict[str, Literal[True]] = {} def visit(nodes): for node in sorted(nodes, reverse=True): @@ -353,7 +358,7 @@ def visit(nodes): return stack -def generate_vegalite_schema_wrapper(schema_file): +def generate_vegalite_schema_wrapper(schema_file: str) -> str: """Generate a schema wrapper at the given path.""" # TODO: generate simple tests for each wrapper basename = "VegaLiteSchema" @@ -361,7 +366,7 @@ def generate_vegalite_schema_wrapper(schema_file): with open(schema_file, encoding="utf8") as f: rootschema = json.load(f) - definitions = {} + definitions: Dict[str, SchemaGenerator] = {} for name in rootschema["definitions"]: defschema = {"$ref": "#/definitions/" + name} @@ -376,7 +381,7 @@ def generate_vegalite_schema_wrapper(schema_file): rootschemarepr=CodeSnippet("{}._rootschema".format(basename)), ) - graph = {} + graph: Dict[str, List[str]] = {} for name, schema in definitions.items(): graph[name] = [] @@ -411,7 +416,9 @@ def generate_vegalite_schema_wrapper(schema_file): return "\n".join(contents) -def generate_vegalite_channel_wrappers(schemafile, version, imports=None): +def generate_vegalite_channel_wrappers( + schemafile: str, version: str, imports: Optional[List[str]] = None +) -> str: # TODO: generate __all__ for top of file with open(schemafile, encoding="utf8") as f: schema = json.load(f) @@ -449,6 +456,11 @@ def generate_vegalite_channel_wrappers(schemafile, version, imports=None): defschema = {"$ref": definition} + Generator: Union[ + Type[FieldSchemaGenerator], + Type[DatumSchemaGenerator], + Type[ValueSchemaGenerator], + ] if encoding_spec == "field": Generator = FieldSchemaGenerator nodefault = [] @@ -485,7 +497,9 @@ def generate_vegalite_channel_wrappers(schemafile, version, imports=None): return "\n".join(contents) -def generate_vegalite_mark_mixin(schemafile, markdefs): +def generate_vegalite_mark_mixin( + schemafile: str, markdefs: Dict[str, str] +) -> Tuple[List[str], str]: with open(schemafile, encoding="utf8") as f: schema = json.load(f) @@ -535,7 +549,7 @@ def generate_vegalite_mark_mixin(schemafile, markdefs): return imports, "\n".join(code) -def generate_vegalite_config_mixin(schemafile): +def generate_vegalite_config_mixin(schemafile: str) -> Tuple[List[str], str]: imports = ["from . import core", "from altair.utils import use_signature"] class_name = "ConfigMethodMixin" @@ -561,7 +575,7 @@ def generate_vegalite_config_mixin(schemafile): return imports, "\n".join(code) -def vegalite_main(skip_download=False): +def vegalite_main(skip_download: bool = False) -> None: library = "vega-lite" for version in SCHEMA_VERSION[library]: @@ -631,7 +645,7 @@ def vegalite_main(skip_download=False): f.write(config_mixin) -def main(): +def main() -> None: parser = argparse.ArgumentParser( prog="generate_schema_wrapper.py", description="Generate the Altair package." ) diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index e05cac91a..320d6d109 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -4,16 +4,20 @@ import re import textwrap import urllib +from typing import Final, Optional, List, Dict, Literal, Union from .schemapi import _resolve_references as resolve_references -EXCLUDE_KEYS = ("definitions", "title", "description", "$schema", "id") +EXCLUDE_KEYS: Final = ("definitions", "title", "description", "$schema", "id") def get_valid_identifier( - prop, replacement_character="", allow_unicode=False, url_decode=True -): + prop: str, + replacement_character: str = "", + allow_unicode: bool = False, + url_decode: bool = True, +) -> str: """Given a string property, generate a valid Python identifier Parameters @@ -70,7 +74,7 @@ def get_valid_identifier( return valid -def is_valid_identifier(var, allow_unicode=False): +def is_valid_identifier(var: str, allow_unicode: bool = False): """Return true if var contains a valid Python identifier Parameters @@ -88,15 +92,17 @@ def is_valid_identifier(var, allow_unicode=False): class SchemaProperties: """A wrapper for properties within a schema""" - def __init__(self, properties, schema, rootschema=None): + def __init__( + self, properties: dict, schema: dict, rootschema: Optional[dict] = None + ) -> None: self._properties = properties self._schema = schema self._rootschema = rootschema or schema - def __bool__(self): + def __bool__(self) -> bool: return bool(self._properties) - def __dir__(self): + def __dir__(self) -> List[str]: return list(self._properties.keys()) def __getattr__(self, attr): @@ -127,7 +133,9 @@ def values(self): class SchemaInfo: """A wrapper for inspecting a JSON schema""" - def __init__(self, schema, rootschema=None): + def __init__( + self, schema: dict, rootschema: Optional[dict] = None + ) -> None: if hasattr(schema, "_schema"): if hasattr(schema, "_rootschema"): schema, rootschema = schema._schema, schema._rootschema @@ -139,10 +147,10 @@ def __init__(self, schema, rootschema=None): self.rootschema = rootschema self.schema = resolve_references(schema, rootschema) - def child(self, schema): + def child(self, schema: dict) -> "SchemaInfo": return self.__class__(schema, rootschema=self.rootschema) - def __repr__(self): + def __repr__(self) -> str: keys = [] for key in sorted(self.schema.keys()): val = self.schema[key] @@ -157,21 +165,21 @@ def __repr__(self): return "SchemaInfo({\n " + "\n ".join(keys) + "\n})" @property - def title(self): + def title(self) -> str: if self.is_reference(): return get_valid_identifier(self.refname) else: return "" @property - def short_description(self): + def short_description(self) -> str: if self.title: # use RST syntax for generated sphinx docs return ":class:`{}`".format(self.title) else: return self.medium_description - _simple_types = { + _simple_types: Dict[str, str] = { "string": "string", "number": "float", "integer": "integer", @@ -182,7 +190,7 @@ def short_description(self): } @property - def medium_description(self): + def medium_description(self) -> str: if self.is_list(): return "[{0}]".format( ", ".join(self.child(s).short_description for s in self.schema) @@ -228,79 +236,79 @@ def medium_description(self): return "any" @property - def long_description(self): + def long_description(self) -> str: # TODO return "Long description including arguments and their types" @property - def properties(self): + def properties(self) -> SchemaProperties: return SchemaProperties( self.schema.get("properties", {}), self.schema, self.rootschema ) @property - def definitions(self): + def definitions(self) -> SchemaProperties: return SchemaProperties( self.schema.get("definitions", {}), self.schema, self.rootschema ) @property - def required(self): + def required(self) -> list: return self.schema.get("required", []) @property - def patternProperties(self): + def patternProperties(self) -> dict: return self.schema.get("patternProperties", {}) @property - def additionalProperties(self): + def additionalProperties(self) -> bool: return self.schema.get("additionalProperties", True) @property - def type(self): + def type(self) -> Optional[str]: return self.schema.get("type", None) @property - def anyOf(self): + def anyOf(self) -> list: return [self.child(s) for s in self.schema.get("anyOf", [])] @property - def oneOf(self): + def oneOf(self) -> list: return [self.child(s) for s in self.schema.get("oneOf", [])] @property - def allOf(self): + def allOf(self) -> list: return [self.child(s) for s in self.schema.get("allOf", [])] @property - def not_(self): + def not_(self) -> dict: return self.child(self.schema.get("not", {})) @property - def items(self): + def items(self) -> dict: return self.schema.get("items", {}) @property - def enum(self): + def enum(self) -> list: return self.schema.get("enum", []) @property - def refname(self): + def refname(self) -> str: return self.raw_schema.get("$ref", "#/").split("/")[-1] @property - def ref(self): + def ref(self) -> Optional[str]: return self.raw_schema.get("$ref", None) @property - def description(self): + def description(self) -> str: return self._get_description(include_sublevels=False) @property - def deep_description(self): + def deep_description(self) -> str: return self._get_description(include_sublevels=True) - def _get_description(self, include_sublevels: bool = False): + def _get_description(self, include_sublevels: bool = False) -> str: desc = self.raw_schema.get("description", self.schema.get("description", "")) if not desc and include_sublevels: for item in self.anyOf: @@ -316,34 +324,34 @@ def _get_description(self, include_sublevels: bool = False): desc = sub_desc return desc - def is_list(self): + def is_list(self) -> bool: return isinstance(self.schema, list) - def is_reference(self): + def is_reference(self) -> bool: return "$ref" in self.raw_schema - def is_enum(self): + def is_enum(self) -> bool: return "enum" in self.schema - def is_empty(self): + def is_empty(self) -> bool: return not (set(self.schema.keys()) - set(EXCLUDE_KEYS)) - def is_compound(self): + def is_compound(self) -> bool: return any(key in self.schema for key in ["anyOf", "allOf", "oneOf"]) - def is_anyOf(self): + def is_anyOf(self) -> bool: return "anyOf" in self.schema - def is_allOf(self): + def is_allOf(self) -> bool: return "allOf" in self.schema - def is_oneOf(self): + def is_oneOf(self) -> bool: return "oneOf" in self.schema - def is_not(self): + def is_not(self) -> bool: return "not" in self.schema - def is_object(self): + def is_object(self) -> bool: if self.type == "object": return True elif self.type is not None: @@ -358,19 +366,23 @@ def is_object(self): else: raise ValueError("Unclear whether schema.is_object() is True") - def is_value(self): + def is_value(self) -> bool: return not self.is_object() - def is_array(self): + def is_array(self) -> bool: return self.type == "array" - def schema_type(self): + def schema_type( + self, + ) -> Literal["empty", "anyOf", "oneOf", "allOf", "object", "array", "value"]: if self.is_empty(): return "empty" elif self.is_compound(): for key in ["anyOf", "oneOf", "allOf"]: if key in self.schema: return key + else: + raise ValueError("Unclear why schema.is_compound() is True") elif self.is_object(): return "object" elif self.is_array(): @@ -380,7 +392,7 @@ def schema_type(self): else: raise ValueError("Unknown type with keys {}".format(self.schema)) - def property_name_map(self): + def property_name_map(self) -> Dict[str, str]: """ Return a mapping of schema property names to valid Python attribute names @@ -391,7 +403,9 @@ def property_name_map(self): return {prop: val for prop, val in pairs if prop != val} -def indent_arglist(args, indent_level, width=100, lstrip=True): +def indent_arglist( + args: List[str], indent_level: int, width: int = 100, lstrip: bool = True +) -> str: """Indent an argument list for use in generated code""" wrapper = textwrap.TextWrapper( width=width, @@ -405,7 +419,9 @@ def indent_arglist(args, indent_level, width=100, lstrip=True): return wrapped -def indent_docstring(lines, indent_level, width=100, lstrip=True): +def indent_docstring( + lines: List[str], indent_level: int, width: int = 100, lstrip=True +) -> str: """Indent a docstring for use in generated code""" final_lines = [] @@ -467,7 +483,7 @@ def indent_docstring(lines, indent_level, width=100, lstrip=True): return wrapped -def fix_docstring_issues(docstring): +def fix_docstring_issues(docstring: str) -> str: # All lists should start with '*' followed by a whitespace. Fixes the ones # which either do not have a whitespace or/and start with '-' by first replacing # "-" with "*" and then adding a whitespace where necessary diff --git a/tools/update_init_file.py b/tools/update_init_file.py index d657965bf..a7a9a2cd8 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -15,15 +15,15 @@ else: from typing_extensions import Self -from typing import Literal +from typing import Literal, Final # Import Altair from head -ROOT_DIR = abspath(join(dirname(__file__), "..")) +ROOT_DIR: Final = abspath(join(dirname(__file__), "..")) sys.path.insert(0, ROOT_DIR) import altair as alt # noqa: E402 -def update__all__variable(): +def update__all__variable() -> None: """Updates the __all__ variable to all relevant attributes of top-level Altair. This is for example useful to hide deprecated attributes from code completion in Jupyter. @@ -65,7 +65,7 @@ def update__all__variable(): f.write(new_file_content) -def _is_relevant_attribute(attr_name): +def _is_relevant_attribute(attr_name: str) -> bool: attr = getattr(alt, attr_name) if ( getattr(attr, "_deprecated", False) is True From 5a1460283040b5078968da2c5a27f3d8b2b07cb3 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 19:58:05 +0200 Subject: [PATCH 2/4] Remove redundant code relating to support for multiple VL versions and other things --- tools/generate_schema_wrapper.py | 140 ++++++++++++++----------------- tools/schemapi/utils.py | 42 +--------- 2 files changed, 65 insertions(+), 117 deletions(-) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 6325ebacd..02eea6931 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -29,10 +29,7 @@ resolve_references, ) -# Map of version name to github branch name. -SCHEMA_VERSION: Final = { - "vega-lite": {"v5": "v5.14.1"}, -} +SCHEMA_VERSION: Final = "v5.14.1" reLink = re.compile(r"(?<=\[)([^\]]+)(?=\]\([^\)]+\))", re.M) reSpecial = re.compile(r"[*_]{2,3}|`", re.M) @@ -262,18 +259,17 @@ def schema_class(*args, **kwargs) -> str: return SchemaGenerator(*args, **kwargs).schema_class() -def schema_url(library: str, version: str) -> str: - version = SCHEMA_VERSION[library][version] - return SCHEMA_URL_TEMPLATE.format(library=library, version=version) +def schema_url(version: str = SCHEMA_VERSION) -> str: + return SCHEMA_URL_TEMPLATE.format(library="vega-lite", version=version) def download_schemafile( - library: str, version: str, schemapath: str, skip_download: bool = False + version: str, schemapath: str, skip_download: bool = False ) -> str: - url = schema_url(library, version) + url = schema_url(version=version) if not os.path.exists(schemapath): os.makedirs(schemapath) - filename = os.path.join(schemapath, "{library}-schema.json".format(library=library)) + filename = os.path.join(schemapath, "vega-lite-schema.json") if not skip_download: request.urlretrieve(url, filename) elif not os.path.exists(filename): @@ -576,73 +572,65 @@ def generate_vegalite_config_mixin(schemafile: str) -> Tuple[List[str], str]: def vegalite_main(skip_download: bool = False) -> None: - library = "vega-lite" - - for version in SCHEMA_VERSION[library]: - path = abspath(join(dirname(__file__), "..", "altair", "vegalite", version)) - schemapath = os.path.join(path, "schema") - schemafile = download_schemafile( - library=library, - version=version, - schemapath=schemapath, - skip_download=skip_download, - ) + version = SCHEMA_VERSION + path = abspath( + join(dirname(__file__), "..", "altair", "vegalite", version.split(".")[0]) + ) + schemapath = os.path.join(path, "schema") + schemafile = download_schemafile( + version=version, + schemapath=schemapath, + skip_download=skip_download, + ) - # Generate __init__.py file - outfile = join(schemapath, "__init__.py") - print("Writing {}".format(outfile)) - with open(outfile, "w", encoding="utf8") as f: - f.write("# ruff: noqa\n") - f.write("from .core import *\nfrom .channels import *\n") - f.write( - "SCHEMA_VERSION = {!r}\n" "".format(SCHEMA_VERSION[library][version]) - ) - f.write("SCHEMA_URL = {!r}\n" "".format(schema_url(library, version))) - - # Generate the core schema wrappers - outfile = join(schemapath, "core.py") - print("Generating\n {}\n ->{}".format(schemafile, outfile)) - file_contents = generate_vegalite_schema_wrapper(schemafile) - with open(outfile, "w", encoding="utf8") as f: - f.write(file_contents) - - # Generate the channel wrappers - outfile = join(schemapath, "channels.py") - print("Generating\n {}\n ->{}".format(schemafile, outfile)) - code = generate_vegalite_channel_wrappers(schemafile, version=version) - with open(outfile, "w", encoding="utf8") as f: - f.write(code) - - # generate the mark mixin - if version == "v2": - markdefs = {"Mark": "MarkDef"} - else: - markdefs = { - k: k + "Def" for k in ["Mark", "BoxPlot", "ErrorBar", "ErrorBand"] - } - outfile = join(schemapath, "mixins.py") - print("Generating\n {}\n ->{}".format(schemafile, outfile)) - mark_imports, mark_mixin = generate_vegalite_mark_mixin(schemafile, markdefs) - config_imports, config_mixin = generate_vegalite_config_mixin(schemafile) - try_except_imports = [ - "if sys.version_info >= (3, 11):", - " from typing import Self", - "else:", - " from typing_extensions import Self", - ] - stdlib_imports = ["import sys"] - imports = sorted(set(mark_imports + config_imports)) - with open(outfile, "w", encoding="utf8") as f: - f.write(HEADER) - f.write("\n".join(stdlib_imports)) - f.write("\n\n") - f.write("\n".join(imports)) - f.write("\n\n") - f.write("\n".join(try_except_imports)) - f.write("\n\n\n") - f.write(mark_mixin) - f.write("\n\n\n") - f.write(config_mixin) + # Generate __init__.py file + outfile = join(schemapath, "__init__.py") + print("Writing {}".format(outfile)) + with open(outfile, "w", encoding="utf8") as f: + f.write("# ruff: noqa\n") + f.write("from .core import *\nfrom .channels import *\n") + f.write(f"SCHEMA_VERSION = '{version}'\n") + f.write("SCHEMA_URL = {!r}\n" "".format(schema_url(version))) + + # Generate the core schema wrappers + outfile = join(schemapath, "core.py") + print("Generating\n {}\n ->{}".format(schemafile, outfile)) + file_contents = generate_vegalite_schema_wrapper(schemafile) + with open(outfile, "w", encoding="utf8") as f: + f.write(file_contents) + + # Generate the channel wrappers + outfile = join(schemapath, "channels.py") + print("Generating\n {}\n ->{}".format(schemafile, outfile)) + code = generate_vegalite_channel_wrappers(schemafile, version=version) + with open(outfile, "w", encoding="utf8") as f: + f.write(code) + + # generate the mark mixin + markdefs = {k: k + "Def" for k in ["Mark", "BoxPlot", "ErrorBar", "ErrorBand"]} + outfile = join(schemapath, "mixins.py") + print("Generating\n {}\n ->{}".format(schemafile, outfile)) + mark_imports, mark_mixin = generate_vegalite_mark_mixin(schemafile, markdefs) + config_imports, config_mixin = generate_vegalite_config_mixin(schemafile) + try_except_imports = [ + "if sys.version_info >= (3, 11):", + " from typing import Self", + "else:", + " from typing_extensions import Self", + ] + stdlib_imports = ["import sys"] + imports = sorted(set(mark_imports + config_imports)) + with open(outfile, "w", encoding="utf8") as f: + f.write(HEADER) + f.write("\n".join(stdlib_imports)) + f.write("\n\n") + f.write("\n".join(imports)) + f.write("\n\n") + f.write("\n".join(try_except_imports)) + f.write("\n\n\n") + f.write(mark_mixin) + f.write("\n\n\n") + f.write(config_mixin) def main() -> None: diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 320d6d109..73df11e66 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -136,12 +136,7 @@ class SchemaInfo: def __init__( self, schema: dict, rootschema: Optional[dict] = None ) -> None: - if hasattr(schema, "_schema"): - if hasattr(schema, "_rootschema"): - schema, rootschema = schema._schema, schema._rootschema - else: - schema, rootschema = schema._schema, schema._schema - elif not rootschema: + if not rootschema: rootschema = schema self.raw_schema = schema self.rootschema = rootschema @@ -235,11 +230,6 @@ def medium_description(self) -> str: ) return "any" - @property - def long_description(self) -> str: - # TODO - return "Long description including arguments and their types" - @property def properties(self) -> SchemaProperties: return SchemaProperties( @@ -372,36 +362,6 @@ def is_value(self) -> bool: def is_array(self) -> bool: return self.type == "array" - def schema_type( - self, - ) -> Literal["empty", "anyOf", "oneOf", "allOf", "object", "array", "value"]: - if self.is_empty(): - return "empty" - elif self.is_compound(): - for key in ["anyOf", "oneOf", "allOf"]: - if key in self.schema: - return key - else: - raise ValueError("Unclear why schema.is_compound() is True") - elif self.is_object(): - return "object" - elif self.is_array(): - return "array" - elif self.is_value(): - return "value" - else: - raise ValueError("Unknown type with keys {}".format(self.schema)) - - def property_name_map(self) -> Dict[str, str]: - """ - Return a mapping of schema property names to valid Python attribute names - - Only properties which are not valid Python identifiers will be included in - the dictionary. - """ - pairs = [(prop, get_valid_identifier(prop)) for prop in self.properties] - return {prop: val for prop, val in pairs if prop != val} - def indent_arglist( args: List[str], indent_level: int, width: int = 100, lstrip: bool = True From 26fa128cd481650041a49c59555aa00bcd0e2a79 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 20:10:19 +0200 Subject: [PATCH 3/4] Rename codegen._get_args to get_args and use a dataclass as return value --- altair/vegalite/v5/api.py | 2 +- tools/generate_schema_wrapper.py | 12 ++--- tools/schemapi/codegen.py | 75 ++++++++++++++++++++------------ 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index c04645f00..a0b4b91bb 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2859,7 +2859,7 @@ class RepeatChart(TopLevelMixin, core.TopLevelRepeatSpec): # Because TopLevelRepeatSpec is defined as a union as of Vega-Lite schema 4.9, # we set the arguments explicitly here. - # TODO: Should we instead use tools/schemapi/codegen._get_args? + # TODO: Should we instead use tools/schemapi/codegen.get_args? @utils.use_signature(core.TopLevelRepeatSpec) def __init__( self, diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 02eea6931..288cb7f8d 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -519,16 +519,16 @@ def generate_vegalite_mark_mixin( info = SchemaInfo({"$ref": "#/definitions/" + mark_def}, rootschema=schema) # adapted from SchemaInfo.init_code - nonkeyword, required, kwds, invalid_kwds, additional = codegen._get_args(info) - required -= {"type"} - kwds -= {"type"} + arg_info = codegen.get_args(info) + arg_info.required -= {"type"} + arg_info.kwds -= {"type"} def_args = ["self"] + [ - "{}=Undefined".format(p) for p in (sorted(required) + sorted(kwds)) + "{}=Undefined".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) ] - dict_args = ["{0}={0}".format(p) for p in (sorted(required) + sorted(kwds))] + dict_args = ["{0}={0}".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds))] - if additional or invalid_kwds: + if arg_info.additional or arg_info.invalid_kwds: def_args.append("**kwds") dict_args.append("**kwds") diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index acf99d88b..f24f909f0 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -1,39 +1,50 @@ """Code generation utilities""" -from .utils import SchemaInfo, is_valid_identifier, indent_docstring, indent_arglist - -import textwrap import re +import textwrap +from typing import Tuple, Set +from dataclasses import dataclass + +from .utils import SchemaInfo, is_valid_identifier, indent_docstring, indent_arglist class CodeSnippet: """Object whose repr() is a string of code""" - def __init__(self, code): + def __init__(self, code: str): self.code = code - def __repr__(self): + def __repr__(self) -> str: return self.code -def _get_args(info): +@dataclass +class ArgInfo: + nonkeyword: bool + required: Set[str] + kwds: Set[str] + invalid_kwds: Set[str] + additional: bool + + +def get_args(info: SchemaInfo) -> ArgInfo: """Return the list of args & kwds for building the __init__ function""" # TODO: - set additional properties correctly # - handle patternProperties etc. - required = set() - kwds = set() - invalid_kwds = set() + required: Set[str] = set() + kwds: Set[str] = set() + invalid_kwds: Set[str] = set() # TODO: specialize for anyOf/oneOf? if info.is_allOf(): # recursively call function on all children - arginfo = [_get_args(child) for child in info.allOf] - nonkeyword = all(args[0] for args in arginfo) - required = set.union(set(), *(args[1] for args in arginfo)) - kwds = set.union(set(), *(args[2] for args in arginfo)) + arginfo = [get_args(child) for child in info.allOf] + nonkeyword = all(args.nonkeyword for args in arginfo) + required = set.union(set(), *(args.required for args in arginfo)) + kwds = set.union(set(), *(args.kwds for args in arginfo)) kwds -= required - invalid_kwds = set.union(set(), *(args[3] for args in arginfo)) - additional = all(args[4] for args in arginfo) + invalid_kwds = set.union(set(), *(args.invalid_kwds for args in arginfo)) + additional = all(args.additional for args in arginfo) elif info.is_empty() or info.is_compound(): nonkeyword = True additional = True @@ -53,7 +64,13 @@ def _get_args(info): else: raise ValueError("Schema object not understood") - return (nonkeyword, required, kwds, invalid_kwds, additional) + return ArgInfo( + nonkeyword=nonkeyword, + required=required, + kwds=kwds, + invalid_kwds=invalid_kwds, + additional=additional, + ) class SchemaGenerator: @@ -158,7 +175,7 @@ def docstring(self, indent=0): # TODO: add a general description at the top, derived from the schema. # for example, a non-object definition should list valid type, enum # values, etc. - # TODO: use _get_args here for more information on allOf objects + # TODO: use get_args here for more information on allOf objects info = SchemaInfo(self.schema, self.rootschema) doc = ["{} schema wrapper".format(self.classname), "", info.medium_description] if info.description: @@ -172,9 +189,13 @@ def docstring(self, indent=0): doc = [line for line in doc if ":raw-html:" not in line] if info.properties: - nonkeyword, required, kwds, invalid_kwds, additional = _get_args(info) + arg_info = get_args(info) doc += ["", "Parameters", "----------", ""] - for prop in sorted(required) + sorted(kwds) + sorted(invalid_kwds): + for prop in ( + sorted(arg_info.required) + + sorted(arg_info.kwds) + + sorted(arg_info.invalid_kwds) + ): propinfo = info.properties[prop] doc += [ "{} : {}".format(prop, propinfo.short_description), @@ -189,30 +210,30 @@ def docstring(self, indent=0): def init_code(self, indent=0): """Return code suitable for the __init__ function of a Schema class""" info = SchemaInfo(self.schema, rootschema=self.rootschema) - nonkeyword, required, kwds, invalid_kwds, additional = _get_args(info) + arg_info = get_args(info) nodefault = set(self.nodefault) - required -= nodefault - kwds -= nodefault + arg_info.required -= nodefault + arg_info.kwds -= nodefault args = ["self"] super_args = [] - self.init_kwds = sorted(kwds) + self.init_kwds = sorted(arg_info.kwds) if nodefault: args.extend(sorted(nodefault)) - elif nonkeyword: + elif arg_info.nonkeyword: args.append("*args") super_args.append("*args") - args.extend("{}=Undefined".format(p) for p in sorted(required) + sorted(kwds)) + args.extend("{}=Undefined".format(p) for p in sorted(arg_info.required) + sorted(arg_info.kwds)) super_args.extend( "{0}={0}".format(p) - for p in sorted(nodefault) + sorted(required) + sorted(kwds) + for p in sorted(nodefault) + sorted(arg_info.required) + sorted(arg_info.kwds) ) - if additional: + if arg_info.additional: args.append("**kwds") super_args.append("**kwds") From 6a7a2d36ee24e8ac9652267343ac2094470cfaf1 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 20:30:41 +0200 Subject: [PATCH 4/4] Remaining type hints --- tools/generate_schema_wrapper.py | 19 +++++--- tools/schemapi/codegen.py | 83 ++++++++++++++++++++------------ tools/schemapi/utils.py | 30 ++++++------ 3 files changed, 78 insertions(+), 54 deletions(-) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 288cb7f8d..c692b9e2c 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -311,7 +311,7 @@ def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None: def get_field_datum_value_defs(propschema: SchemaInfo, root: dict) -> dict: - def_dict = {k: None for k in ("field", "datum", "value")} + def_dict: Dict[str, Optional[str]] = {k: None for k in ("field", "datum", "value")} schema = propschema.schema if propschema.is_reference() and "properties" in schema: if "field" in schema["properties"]: @@ -381,13 +381,14 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str: for name, schema in definitions.items(): graph[name] = [] - for child in schema.subclasses(): - child = get_valid_identifier(child) - graph[name].append(child) - child = definitions[child] + for child_name in schema.subclasses(): + child_name = get_valid_identifier(child_name) + graph[name].append(child_name) + child: SchemaGenerator = definitions[child_name] if child.basename == basename: child.basename = [name] else: + assert isinstance(child.basename, list) child.basename.append(name) contents = [ @@ -524,9 +525,13 @@ def generate_vegalite_mark_mixin( arg_info.kwds -= {"type"} def_args = ["self"] + [ - "{}=Undefined".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) + "{}=Undefined".format(p) + for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) + ] + dict_args = [ + "{0}={0}".format(p) + for p in (sorted(arg_info.required) + sorted(arg_info.kwds)) ] - dict_args = ["{0}={0}".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds))] if arg_info.additional or arg_info.invalid_kwds: def_args.append("**kwds") diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index f24f909f0..3f97bdf2f 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -1,10 +1,16 @@ """Code generation utilities""" import re import textwrap -from typing import Tuple, Set +from typing import Set, Final, Optional, List, Iterable, Union from dataclasses import dataclass -from .utils import SchemaInfo, is_valid_identifier, indent_docstring, indent_arglist +from .utils import ( + SchemaInfo, + is_valid_identifier, + indent_docstring, + indent_arglist, + SchemaProperties, +) class CodeSnippet: @@ -84,7 +90,7 @@ class SchemaGenerator: The dictionary defining the schema class rootschema : dict (optional) The root schema for the class - basename : string or tuple (default: "SchemaBase") + basename : string or list of strings (default: "SchemaBase") The name(s) of the base class(es) to use in the class definition schemarepr : CodeSnippet or object, optional An object whose repr will be used in the place of the explicit schema. @@ -109,47 +115,51 @@ class {classname}({basename}): ''' ) - init_template = textwrap.dedent( + init_template: Final = textwrap.dedent( """ def __init__({arglist}): super({classname}, self).__init__({super_arglist}) """ ).lstrip() - def _process_description(self, description): + def _process_description(self, description: str): return description def __init__( self, - classname, - schema, - rootschema=None, - basename="SchemaBase", - schemarepr=None, - rootschemarepr=None, - nodefault=(), - haspropsetters=False, + classname: str, + schema: dict, + rootschema: Optional[dict] = None, + basename: Union[str, List[str]] = "SchemaBase", + schemarepr: Optional[object] = None, + rootschemarepr: Optional[object] = None, + nodefault: Optional[List[str]] = None, + haspropsetters: bool = False, **kwargs, - ): + ) -> None: self.classname = classname self.schema = schema self.rootschema = rootschema self.basename = basename self.schemarepr = schemarepr self.rootschemarepr = rootschemarepr - self.nodefault = nodefault + self.nodefault = nodefault or () self.haspropsetters = haspropsetters self.kwargs = kwargs - def subclasses(self): + def subclasses(self) -> List[str]: """Return a list of subclass names, if any.""" info = SchemaInfo(self.schema, self.rootschema) return [child.refname for child in info.anyOf if child.is_reference()] - def schema_class(self): + def schema_class(self) -> str: """Generate code for a schema class""" - rootschema = self.rootschema if self.rootschema is not None else self.schema - schemarepr = self.schemarepr if self.schemarepr is not None else self.schema + rootschema: dict = ( + self.rootschema if self.rootschema is not None else self.schema + ) + schemarepr: object = ( + self.schemarepr if self.schemarepr is not None else self.schema + ) rootschemarepr = self.rootschemarepr if rootschemarepr is None: if rootschema is self.schema: @@ -171,7 +181,7 @@ def schema_class(self): **self.kwargs, ) - def docstring(self, indent=0): + def docstring(self, indent: int = 0) -> str: # TODO: add a general description at the top, derived from the schema. # for example, a non-object definition should list valid type, enum # values, etc. @@ -207,7 +217,7 @@ def docstring(self, indent=0): doc += [""] return indent_docstring(doc, indent_level=indent, width=100, lstrip=True) - def init_code(self, indent=0): + def init_code(self, indent: int = 0) -> str: """Return code suitable for the __init__ function of a Schema class""" info = SchemaInfo(self.schema, rootschema=self.rootschema) arg_info = get_args(info) @@ -216,8 +226,8 @@ def init_code(self, indent=0): arg_info.required -= nodefault arg_info.kwds -= nodefault - args = ["self"] - super_args = [] + args: List[str] = ["self"] + super_args: List[str] = [] self.init_kwds = sorted(arg_info.kwds) @@ -227,10 +237,15 @@ def init_code(self, indent=0): args.append("*args") super_args.append("*args") - args.extend("{}=Undefined".format(p) for p in sorted(arg_info.required) + sorted(arg_info.kwds)) + args.extend( + "{}=Undefined".format(p) + for p in sorted(arg_info.required) + sorted(arg_info.kwds) + ) super_args.extend( "{0}={0}".format(p) - for p in sorted(nodefault) + sorted(arg_info.required) + sorted(arg_info.kwds) + for p in sorted(nodefault) + + sorted(arg_info.required) + + sorted(arg_info.kwds) ) if arg_info.additional: @@ -261,9 +276,9 @@ def init_code(self, indent=0): "null": "None", } - def get_args(self, si): + def get_args(self, si: SchemaInfo) -> List[str]: contents = ["self"] - props = [] + props: Union[List[str], SchemaProperties] = [] if si.is_anyOf(): props = sorted({p for si_sub in si.anyOf for p in si_sub.properties}) elif si.properties: @@ -296,7 +311,9 @@ def get_args(self, si): return contents - def get_signature(self, attr, sub_si, indent, has_overload=False): + def get_signature( + self, attr: str, sub_si: SchemaInfo, indent: int, has_overload: bool = False + ) -> List[str]: lines = [] if has_overload: lines.append("@overload # type: ignore[no-overload-impl]") @@ -305,14 +322,16 @@ def get_signature(self, attr, sub_si, indent, has_overload=False): lines.append(indent * " " + "...\n") return lines - def setter_hint(self, attr, indent): + def setter_hint(self, attr: str, indent: int) -> List[str]: si = SchemaInfo(self.schema, self.rootschema).properties[attr] if si.is_anyOf(): return self._get_signature_any_of(si, attr, indent) else: return self.get_signature(attr, si, indent) - def _get_signature_any_of(self, si: SchemaInfo, attr, indent): + def _get_signature_any_of( + self, si: SchemaInfo, attr: str, indent: int + ) -> List[str]: signatures = [] for sub_si in si.anyOf: if sub_si.is_anyOf(): @@ -324,7 +343,7 @@ def _get_signature_any_of(self, si: SchemaInfo, attr, indent): ) return list(flatten(signatures)) - def method_code(self, indent=0): + def method_code(self, indent: int = 0) -> Optional[str]: """Return code to assist setter methods""" if not self.haspropsetters: return None @@ -334,7 +353,7 @@ def method_code(self, indent=0): return ("\n" + indent * " ").join(type_hints) -def flatten(container): +def flatten(container: Iterable) -> Iterable: """Flatten arbitrarily flattened list From https://stackoverflow.com/a/10824420 diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 73df11e66..7a3d27408 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -4,7 +4,7 @@ import re import textwrap import urllib -from typing import Final, Optional, List, Dict, Literal, Union +from typing import Final, Optional, List, Dict, Any from .schemapi import _resolve_references as resolve_references @@ -93,7 +93,10 @@ class SchemaProperties: """A wrapper for properties within a schema""" def __init__( - self, properties: dict, schema: dict, rootschema: Optional[dict] = None + self, + properties: Dict[str, Any], + schema: dict, + rootschema: Optional[dict] = None, ) -> None: self._properties = properties self._schema = schema @@ -134,7 +137,7 @@ class SchemaInfo: """A wrapper for inspecting a JSON schema""" def __init__( - self, schema: dict, rootschema: Optional[dict] = None + self, schema: Dict[str, Any], rootschema: Optional[Dict[str, Any]] = None ) -> None: if not rootschema: rootschema = schema @@ -186,11 +189,7 @@ def short_description(self) -> str: @property def medium_description(self) -> str: - if self.is_list(): - return "[{0}]".format( - ", ".join(self.child(s).short_description for s in self.schema) - ) - elif self.is_empty(): + if self.is_empty(): return "Any" elif self.is_enum(): return "enum({})".format(", ".join(map(repr, self.enum))) @@ -229,6 +228,10 @@ def medium_description(self) -> str: stacklevel=1, ) return "any" + else: + raise ValueError( + "No medium_description available for this schema for schema" + ) @property def properties(self) -> SchemaProperties: @@ -259,19 +262,19 @@ def type(self) -> Optional[str]: return self.schema.get("type", None) @property - def anyOf(self) -> list: + def anyOf(self) -> List["SchemaInfo"]: return [self.child(s) for s in self.schema.get("anyOf", [])] @property - def oneOf(self) -> list: + def oneOf(self) -> List["SchemaInfo"]: return [self.child(s) for s in self.schema.get("oneOf", [])] @property - def allOf(self) -> list: + def allOf(self) -> List["SchemaInfo"]: return [self.child(s) for s in self.schema.get("allOf", [])] @property - def not_(self) -> dict: + def not_(self) -> "SchemaInfo": return self.child(self.schema.get("not", {})) @property @@ -314,9 +317,6 @@ def _get_description(self, include_sublevels: bool = False) -> str: desc = sub_desc return desc - def is_list(self) -> bool: - return isinstance(self.schema, list) - def is_reference(self) -> bool: return "$ref" in self.raw_schema