diff --git a/impall.py b/impall.py index 441af1f..03c1542 100755 --- a/impall.py +++ b/impall.py @@ -67,8 +67,10 @@ class ImpAllTest(impall.ImpAllTest): import os import sys import traceback +import typing as t import unittest import warnings +from types import ModuleType __author__ = 'Tom Ritchford ' __all__ = 'ImpAllTest', 'path_to_import' @@ -129,17 +131,17 @@ class ImpAllTest(unittest.TestCase): WARNINGS_ACTION = 'default' @functools.cached_property - def _exc(self): + def _exc(self) -> t.Callable[[t.Any], bool]: return _split_pattern(self.EXCLUDE or ()) @functools.cached_property - def _inc(self): + def _inc(self) -> t.Callable[[t.Any], bool]: if self.INCLUDE is None: return lambda s: True else: return _split_pattern(self.INCLUDE) - def test_all(self): + def test_all(self) -> None: successes, failures = self.impall() self.assertTrue(successes or failures, 'No tests were found') expected = sorted(_split(self.FAILING)) @@ -164,18 +166,19 @@ def test_all(self): self.assertTrue(not failed, 'Some tests failed') self.assertTrue(not failed_to_fail, 'Some tests failed to fail') - def impall(self): - successes, failures = [], [] + def impall(self) -> t.Tuple[t.List[str], t.List[t.Tuple[str, str]]]: + successes: t.List[str] = [] + failures: t.List[t.Tuple[str, str]] = [] paths = _split(self.PATHS or path_to_import(os.getcwd())[0]) - warnings.simplefilter(self.WARNINGS_ACTION) + warnings.simplefilter(self.WARNINGS_ACTION) # type: ignore[arg-type] for file in self._all_imports(paths): self._import(file, successes, failures) - warnings.filters.pop(0) + warnings.filters.pop(0) # type: ignore[attr-defined] return successes, failures - def _all_imports(self, paths): + def _all_imports(self, paths: t.Sequence[str]) -> t.Iterator[str]: for path in paths: for directory, sub_dirs, files in os.walk(path): if directory != path and not self._accept_dir(directory): @@ -189,7 +192,7 @@ def _all_imports(self, paths): if f.endswith('.py') and not _is_ignored(f): yield os.path.join(directory, f) - def _import(self, file, successes, failures): + def _import(self, file: str, successes: t.List[str], failures: t.List[t.Tuple[str, str]]) -> None: root, module = path_to_import(file) path = file[:-3] if file.endswith('.py') else file @@ -221,14 +224,14 @@ def _import(self, file, successes, failures): sys.modules.update(saved_modules) sys.path[:] = saved_path - def _accept_dir(self, directory): + def _accept_dir(self, directory: str) -> bool: if self.MODULES: return _is_python_dir(directory) return not _is_ignored(directory) @functools.lru_cache() -def path_to_import(path): +def path_to_import(path: str) -> t.Tuple[str, str]: """ Return a (path, module) pair that allows you to import the Python file or directory at location path @@ -242,20 +245,21 @@ def path_to_import(path): if path.endswith('.py'): path = path[:-3] - def isdir(p): + def isdir(p: str) -> bool: return os.path.isdir(p) and not os.path.exists(p + '.py') while not isdir(path) or _is_python_dir(path): path, part = os.path.split(path) if not part: - path and parts.append(path) + if path: + parts.append(path) break parts.append(part) return path, '.'.join(reversed(parts)) -def import_file(path): +def import_file(path: str) -> ModuleType: """ Given a path to a file or directory, imports it from the correct root and returns the module @@ -279,20 +283,20 @@ def import_file(path): _NO = 'NO_' -def _is_ignored(path): +def _is_ignored(path: str) -> bool: b = os.path.basename(path) return b.startswith('.') or ( b.startswith('__') and os.path.isdir(path) or b == '__init__.py' ) -def _is_python_dir(path): +def _is_python_dir(path: str) -> bool: """Return True if `path` is a directory containing an __init__.py file""" init = os.path.join(path, '__init__.py') return os.path.exists(init) and not _is_ignored(path) -def _split(s): +def _split(s: t.Union[str, t.Sequence[str]]) -> t.Sequence[str]: if not s: return [] if isinstance(s, str): @@ -300,12 +304,12 @@ def _split(s): return s -def _split_pattern(s): +def _split_pattern(s: t.Union[str, t.Sequence[str]]) -> t.Callable[[t.Any], bool]: parts = _split(s) return lambda x: any(fnmatch.fnmatch(x, p) for p in parts) -def report(): +def report() -> None: """Test all files in a directory from the command line""" args = _parse_args() test_case = ImpAllTest() @@ -318,7 +322,7 @@ def report(): default = getattr(test_case, attr, _NO) if default is not _NO and (isinstance(value, bool) or value): - if isinstance(default, (list, tuple)): + if isinstance(default, (list, tuple)) and isinstance(value, str): value = value.split(ENV_SEPARATOR) setattr(test_case, attr, value) @@ -328,18 +332,20 @@ def report(): print() if failures: - failures = ['%s (%s)' % (m, e) for (m, e) in failures] - print('Failures', *failures, sep='\n ', file=sys.stderr) + fail = ['%s (%s)' % (m, e) for (m, e) in failures] + print('Failures', *fail, sep='\n ', file=sys.stderr) print(file=sys.stderr) -def _parse_args(): +def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=_USAGE) parser.add_argument('paths', nargs='*', default=[os.getcwd()]) + kwds: t.Dict[str, t.Any] for prop in PROPERTIES: default = getattr(ImpAllTest, prop) help = globals()[prop] + if isinstance(default, bool): kwds = {'action': 'store_true'} if default: diff --git a/pyproject.toml b/pyproject.toml index eba0de9..dd7c377 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ line-length = 88 [tool.ruff.format] quote-style = "single" + +[tool.mypy] +strict = true [build-system] build-backend = "poetry.core.masonry.api" requires = ["poetry-core"]