Skip to content

Commit

Permalink
Fix some type hints in test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Jan 6, 2025
1 parent 4483ec5 commit 59d1373
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions rewrite/rewrite/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
from dataclasses import dataclass, field
from io import StringIO
from pathlib import Path
from typing import Optional, Callable, Iterable, List, TypeVar
from typing import Optional, Callable, Iterable, List, TypeVar, Any, cast
from uuid import UUID

from rewrite import InMemoryExecutionContext, ParserInput, ParserBuilder, random_id, ParseError, ParseExceptionResult, \
ExecutionContext, Recipe, TreeVisitor, SourceFile
from rewrite.execution import InMemoryLargeSourceSet
from rewrite.python import CompilationUnit
from rewrite.python.parser import PythonParserBuilder


S = TypeVar('S', bound=SourceFile)

@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -49,10 +47,10 @@ def after(self) -> Optional[Callable[[str], str]]:
def source_path(self) -> Optional[Path]:
return self._source_path

_after_recipe: Callable[[S], None] = lambda _: None
_after_recipe: Optional[Callable[[S], None]] = lambda _: None

@property
def after_recipe(self) -> Callable[[S], None]:
def after_recipe(self) -> Optional[Callable[[S], None]]:
return self._after_recipe


Expand All @@ -66,16 +64,16 @@ def get_recipe_list(self) -> List[Recipe]:

@dataclass(frozen=True, eq=False)
class RecipeSpec:
_recipe: Recipe = None
_recipe: Optional[Recipe] = None

@property
def recipe(self) -> Recipe:
def recipe(self) -> Optional[Recipe]:
return self._recipe

def with_recipe(self, recipe: Recipe) -> RecipeSpec:
return self if recipe is self._recipe else RecipeSpec(recipe)

def with_recipes(self, *recipes: Recipe):
def with_recipes(self, *recipes: Recipe) -> RecipeSpec:
return RecipeSpec(CompositeRecipe(recipes))

_parsers: Iterable[ParserBuilder] = field(default_factory=list)
Expand All @@ -88,7 +86,7 @@ def with_parsers(self, parsers: Iterable[ParserBuilder]) -> RecipeSpec:
return self if parsers is self._parsers else RecipeSpec(self._recipe, parsers)


def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
def rewrite_run(*source_specs: Iterable[SourceSpec], spec: Optional[RecipeSpec] = None) -> None:
from rewrite_remote import RemotingContext, RemotePrinterFactory
from rewrite_remote.server import register_remoting_factories
remoting_context = RemotingContext()
Expand All @@ -114,13 +112,13 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
for source_file in parser.parse_inputs(
[ParserInput(source_path, None, True, lambda: StringIO(source_spec.before))], None, ctx):
if isinstance(source_file, ParseError):
assert False, f'Parser threw an exception:\n%{source_file.markers.find_first(ParseExceptionResult).message}'
assert False, f'Parser threw an exception:\n%{source_file.markers.find_first(ParseExceptionResult).message}' # type: ignore
remoting_context.client.reset()
assert source_file.print_all() == source_spec.before

spec_by_source_file[source_file] = source_spec

if spec:
if spec and spec.recipe:
recipe = spec.recipe
before = InMemoryLargeSourceSet(list(spec_by_source_file.keys()))
result = recipe.run(before, ctx)
Expand All @@ -141,24 +139,26 @@ def rewrite_run(*source_specs: Iterable[SourceSpec], spec: RecipeSpec = None):
remoting_context.close()


def python(before: str, after: str = None, after_recipe: Callable[[CompilationUnit], None] = lambda s: None) -> list[SourceSpec]:
S2 = TypeVar('S2', bound=SourceFile)

def python(before: str, after: Optional[str] = None, after_recipe: Optional[Callable[[S2], None]] = lambda s: None) -> list[SourceSpec]:
return [SourceSpec(
random_id(),
PythonParserBuilder(),
textwrap.dedent(before),
None if after is None else lambda _: textwrap.dedent(after),
None,
after_recipe
cast(Optional[Callable[[S], None]], after_recipe)
)]


def from_visitor(visitor: TreeVisitor[any, any]) -> Recipe:
def from_visitor(visitor: TreeVisitor[Any, Any]) -> Recipe:
return AdHocRecipe(visitor)


@dataclass(frozen=True)
class AdHocRecipe(Recipe):
visitor: TreeVisitor[any, any]
visitor: TreeVisitor[Any, Any]

def get_visitor(self) -> TreeVisitor[any, any]:
def get_visitor(self) -> TreeVisitor[Any, Any]:
return self.visitor

0 comments on commit 59d1373

Please sign in to comment.